mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-04-03 05:26:59 +08:00
TorchCompileModelFluxAdvanced
This commit is contained in:
parent
1ec5810868
commit
19ec49ae2b
@ -152,6 +152,7 @@ NODE_CONFIG = {
|
||||
"FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
|
||||
"CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
|
||||
"CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
|
||||
"TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -2238,4 +2238,62 @@ class CheckpointLoaderKJ:
|
||||
|
||||
|
||||
return model, clip, vae
|
||||
|
||||
import re
|
||||
class TorchCompileModelFluxAdvanced:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"backend": (["inductor", "cudagraphs"],),
|
||||
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
|
||||
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def parse_blocks(self, blocks_str):
|
||||
blocks = []
|
||||
for part in blocks_str.split(','):
|
||||
part = part.strip()
|
||||
if '-' in part:
|
||||
start, end = map(int, part.split('-'))
|
||||
blocks.extend(range(start, end + 1))
|
||||
else:
|
||||
blocks.append(int(part))
|
||||
return blocks
|
||||
|
||||
def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list):
|
||||
#print("Diffusion model object before compilation:", diffusion_model)
|
||||
for i, block in enumerate(diffusion_model.double_blocks):
|
||||
if i in double_block_list:
|
||||
print("Compiling double block", i)
|
||||
diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
for i, block in enumerate(diffusion_model.single_blocks):
|
||||
if i in single_block_list:
|
||||
print("Compiling single block", i)
|
||||
diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
|
||||
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
|
||||
#compiled_model = torch.compile(model=diffusion_model, backend=backend)
|
||||
#print("Compiled diffusion model object:", compiled_model)
|
||||
return diffusion_model
|
||||
|
||||
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
|
||||
single_block_list = self.parse_blocks(single_blocks)
|
||||
double_block_list = self.parse_blocks(double_blocks)
|
||||
m = model.clone()
|
||||
diffusion_model = m.get_model_object("diffusion_model")
|
||||
#self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)
|
||||
m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list))
|
||||
return (m, )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user