From 504044f1810e1d1415c0924fffbe0ba7ce43032e Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 6 Aug 2025 00:09:15 +0300
Subject: [PATCH] Add TorchCompileModelQwenImage

Mostly to limit compile to transformer blocks only for less recompiles
---
 __init__.py                       |  1 +
 nodes/model_optimization_nodes.py | 43 ++++++++++++++++++++++++++++---
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/__init__.py b/__init__.py
index 3a950fc..590ed7c 100644
--- a/__init__.py
+++ b/__init__.py
@@ -183,6 +183,7 @@ NODE_CONFIG = {
     "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},
     "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"},
     "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"},
+    "TorchCompileModelQwenImage": {"class": TorchCompileModelQwenImage, "name": "TorchCompileModelQwenImage"},
     "TorchCompileModelWanVideo": {"class": TorchCompileModelWanVideo, "name": "TorchCompileModelWanVideo"},
     "TorchCompileModelWanVideoV2": {"class": TorchCompileModelWanVideoV2, "name": "TorchCompileModelWanVideoV2"},
     "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 1f27778..596f04e 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -811,9 +811,6 @@ class TorchCompileModelWanVideo:
         return (m, )
 
 class TorchCompileModelWanVideoV2:
-    def __init__(self):
-        self._compiled = False
-
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -851,6 +848,46 @@ class TorchCompileModelWanVideoV2:
             raise RuntimeError("Failed to compile model")
         return (m, )
 
+
+class TorchCompileModelQwenImage:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
+                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+                "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+            },
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "KJNodes/torchcompile"
+    EXPERIMENTAL = True
+
+    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
+        from comfy_api.torch_helpers import set_torch_compile_wrapper
+        m = model.clone()
+        diffusion_model = m.get_model_object("diffusion_model")
+        print(diffusion_model)
+        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+        try:
+            if compile_transformer_blocks_only:
+                compile_key_list = []
+                for i, block in enumerate(diffusion_model.transformer_blocks):
+                    compile_key_list.append(f"diffusion_model.transformer_blocks.{i}")
+            else:
+                compile_key_list =["diffusion_model"]
+
+            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
+        except:
+            raise RuntimeError("Failed to compile model")
+
+        return (m, )
 
 class TorchCompileVAE:
     def __init__(self):