diff --git a/__init__.py b/__init__.py
index 60128a8..51593c4 100644
--- a/__init__.py
+++ b/__init__.py
@@ -182,6 +182,7 @@ NODE_CONFIG = {
     "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"},
     "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"},
     "TorchCompileModelWanVideo": {"class": TorchCompileModelWanVideo, "name": "TorchCompileModelWanVideo"},
+    "TorchCompileModelWanVideoV2": {"class": TorchCompileModelWanVideoV2, "name": "TorchCompileModelWanVideoV2"},
     "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
     "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
     "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 0506335..c2e9a9a 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -612,11 +612,11 @@ class TorchCompileModelFluxAdvancedV2:
         try:
             if double_blocks:
-                for block in diffusion_model.double_blocks:
-                    compile_key_list.append(block)
+                for i, block in enumerate(diffusion_model.double_blocks):
+                    compile_key_list.append(f"diffusion_model.double_blocks.{i}")
             if single_blocks:
-                for block in diffusion_model.single_blocks:
-                    compile_key_list.append(block)
+                for i, block in enumerate(diffusion_model.single_blocks):
+                    compile_key_list.append(f"diffusion_model.single_blocks.{i}")
             set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
         except:
             raise RuntimeError("Failed to compile model")
         return (m, )
@@ -743,6 +743,50 @@ class TorchCompileModelWanVideo:
         except:
             raise RuntimeError("Failed to compile model")
         return (m, )
+
+class TorchCompileModelWanVideoV2:
+    def __init__(self):
+        self._compiled = False
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "backend": (["inductor", "cudagraphs"], {"default": "inductor"}),
+                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+                "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only the transformer blocks; faster to compile and less error prone"}),
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+            },
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "KJNodes/torchcompile"
+    EXPERIMENTAL = True
+
+    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
+        from comfy_api.torch_helpers import set_torch_compile_wrapper
+        m = model.clone()
+        diffusion_model = m.get_model_object("diffusion_model")
+        # Compiling many blocks individually can exhaust dynamo's recompile cache, so expose the limit.
+        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+        try:
+            if compile_transformer_blocks_only:
+                compile_key_list = []
+                # Reference each block by its attribute path; the wrapper resolves the key to the submodule.
+                for i, block in enumerate(diffusion_model.blocks):
+                    compile_key_list.append(f"diffusion_model.blocks.{i}")
+            else:
+                compile_key_list = ["diffusion_model"]
+
+            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
+        except Exception as e:
+            raise RuntimeError("Failed to compile model") from e
+
+        return (m, )
 
 class TorchCompileVAE:
     def __init__(self):
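
Note on the key format: the fix in TorchCompileModelFluxAdvancedV2 passes set_torch_compile_wrapper dotted attribute paths ("diffusion_model.double_blocks.0", ...) instead of the block modules themselves, and the new WanVideo V2 node follows the same convention. Below is a minimal sketch of what resolving such a key involves; compile_submodule_by_key is a hypothetical stand-in, not the comfy_api implementation, which additionally has to cooperate with ComfyUI's model patching so compiled blocks survive model clones (something raw module references would not give you).

import torch
import torch.nn as nn

def compile_submodule_by_key(root: nn.Module, key: str, **compile_kwargs):
    # nn.Module.get_submodule understands dotted paths, including numeric
    # segments that index into a ModuleList (e.g. "blocks.3").
    target = root.get_submodule(key)
    compiled = torch.compile(target, **compile_kwargs)
    # Swap the compiled wrapper back in so forward passes actually use it.
    parent_path, _, leaf = key.rpartition(".")
    parent = root.get_submodule(parent_path) if parent_path else root
    if leaf.isdigit():
        parent[int(leaf)] = compiled  # ModuleList / Sequential entry
    else:
        setattr(parent, leaf, compiled)
    return compiled

# Usage mirroring the node: one key per transformer block.
# for i in range(len(model.blocks)):
#     compile_submodule_by_key(model, f"blocks.{i}", backend="inductor", mode="default")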
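Why compile_transformer_blocks_only defaults to True, and why the node exposes dynamo_cache_size_limit: compiling each repeated block as its own unit keeps every compiled graph small and uniform, so cold compiles finish faster and a graph break in the surrounding glue code cannot invalidate the whole model. The flip side is that all blocks share one forward code object, so their guarded cache entries accumulate against the same dynamo limit. A self-contained toy sketch of the pattern (the Block class and sizes are invented for illustration):

import torch
import torch._dynamo
import torch.nn as nn

class Block(nn.Module):
    # Stand-in for one transformer block (invented for illustration).
    def __init__(self, dim=64):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(self.norm(x))

model = nn.Sequential(*[Block() for _ in range(8)])

# Every Block instance shares Block.forward's code object, and each compiled
# instance adds its own guarded entry to dynamo's cache for that code object.
# With dozens of blocks the default limit is quickly exceeded, hence the
# node's dynamo_cache_size_limit input (default 64).
torch._dynamo.config.cache_size_limit = 64

# Per-block compilation: replace each block with its compiled wrapper,
# analogous to one compile key per "diffusion_model.blocks.{i}".
for i in range(len(model)):
    model[i] = torch.compile(model[i], backend="inductor", mode="default")

out = model(torch.randn(2, 16, 64))  # first call triggers compilation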