TorchCompileModelFluxAdvanced

This commit is contained in:
kijai 2024-10-25 20:52:16 +03:00
parent 1ec5810868
commit 19ec49ae2b
2 changed files with 60 additions and 1 deletions

View File

@ -152,6 +152,7 @@ NODE_CONFIG = {
"FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
"CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
"CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
"TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -2238,4 +2238,62 @@ class CheckpointLoaderKJ:
return model, clip, vae
import re
class TorchCompileModelFluxAdvanced:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
"single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def parse_blocks(self, blocks_str):
blocks = []
for part in blocks_str.split(','):
part = part.strip()
if '-' in part:
start, end = map(int, part.split('-'))
blocks.extend(range(start, end + 1))
else:
blocks.append(int(part))
return blocks
def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list):
#print("Diffusion model object before compilation:", diffusion_model)
for i, block in enumerate(diffusion_model.double_blocks):
if i in double_block_list:
print("Compiling double block", i)
diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
for i, block in enumerate(diffusion_model.single_blocks):
if i in single_block_list:
print("Compiling single block", i)
diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
# diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
#compiled_model = torch.compile(model=diffusion_model, backend=backend)
#print("Compiled diffusion model object:", compiled_model)
return diffusion_model
def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks):
single_block_list = self.parse_blocks(single_blocks)
double_block_list = self.parse_blocks(double_blocks)
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
#self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)
m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list))
return (m, )