From 19ec49ae2b7963a6706b938be40f9f29b2c89712 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 25 Oct 2024 20:52:16 +0300 Subject: [PATCH] TorchCompileModelFluxAdvanced --- __init__.py | 1 + nodes/nodes.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 9b8606b..aecf764 100644 --- a/__init__.py +++ b/__init__.py @@ -152,6 +152,7 @@ NODE_CONFIG = { "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"}, "CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"}, "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"}, + "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 65d449e..55529d5 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2238,4 +2238,62 @@ class CheckpointLoaderKJ: return model, clip, vae - \ No newline at end of file +import re +class TorchCompileModelFluxAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL",), + "backend": (["inductor", "cudagraphs"],), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "double_blocks": ("STRING", {"default": "0-18", "multiline": True}), + "single_blocks": ("STRING", {"default": "0-37", "multiline": True}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + EXPERIMENTAL = True + + def parse_blocks(self, blocks_str): + blocks = [] + for part in blocks_str.split(','): + part = part.strip() + if '-' in part: + start, end = map(int, part.split('-')) + blocks.extend(range(start, end + 1)) + else: + blocks.append(int(part)) + return blocks + + def compile_diffusion_model(self, diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list): + #print("Diffusion model object before compilation:", diffusion_model) + for i, block in enumerate(diffusion_model.double_blocks): + if i in double_block_list: + print("Compiling double block", i) + diffusion_model.double_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend) + for i, block in enumerate(diffusion_model.single_blocks): + if i in single_block_list: + print("Compiling single block", i) + diffusion_model.single_blocks[i] = torch.compile(block, mode=mode, fullgraph=fullgraph, backend=backend) + + # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend) + # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend) + # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend) + # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend) + # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend) + # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend) + + #compiled_model = torch.compile(model=diffusion_model, backend=backend) + #print("Compiled diffusion model object:", compiled_model) + return diffusion_model + + def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks): + single_block_list = self.parse_blocks(single_blocks) + double_block_list = self.parse_blocks(double_blocks) + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + #self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list) + m.add_object_patch("diffusion_model", self.compile_diffusion_model(diffusion_model, backend, mode, fullgraph, single_block_list, double_block_list)) + return (m, )