From 1a4259f05206d7360be7a90145b5839d5b64d893 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 26 Feb 2025 11:10:17 +0200 Subject: [PATCH] TorchCompileModelWanVideo --- __init__.py | 1 + nodes/model_optimization_nodes.py | 57 +++++++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/__init__.py b/__init__.py index a9188a5..f78343a 100644 --- a/__init__.py +++ b/__init__.py @@ -176,6 +176,7 @@ NODE_CONFIG = { "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"}, "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"}, "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"}, + "TorchCompileModelWanVideo": {"class": TorchCompileModelWanVideo, "name": "TorchCompileModelWanVideo"}, "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"}, "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"}, "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"}, diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index fabb9a3..3065ea4 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -303,7 +303,7 @@ class TorchCompileModelFluxAdvanced: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def parse_blocks(self, blocks_str): @@ -378,7 +378,7 @@ class TorchCompileModelHyVideo: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer): @@ -415,6 +415,51 @@ class TorchCompileModelHyVideo: except: raise RuntimeError("Failed to compile model") return (m, ) + +class TorchCompileModelWanVideo: + def __init__(self): + self._compiled = False + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "backend": (["inductor","cudagraphs"], {"default": "inductor"}), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), + "compile_transformer_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile all transformer blocks"}), + }, + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "KJNodes/torchcompile" + EXPERIMENTAL = True + + def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks): + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit + if not self._compiled: + try: + if compile_transformer_blocks: + for i, block in enumerate(diffusion_model.blocks): + compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) + m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block) + self._compiled = True + compile_settings = { + "backend": backend, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + setattr(m.model, "compile_settings", compile_settings) + except: + raise RuntimeError("Failed to compile model") + return (m, ) class TorchCompileVAE: def __init__(self): @@ -434,7 +479,7 @@ class TorchCompileVAE: RETURN_TYPES = ("VAE",) FUNCTION = "compile" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder): @@ -495,7 +540,7 @@ class TorchCompileControlNet: RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "compile" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def compile(self, controlnet, backend, mode, fullgraph): @@ -528,7 +573,7 @@ class TorchCompileLTXModel: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, mode, fullgraph, dynamic): @@ -571,7 +616,7 @@ class TorchCompileCosmosModel: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/torchcompile" EXPERIMENTAL = True def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit):