diff --git a/__init__.py b/__init__.py index 2974106..80fb6c7 100644 --- a/__init__.py +++ b/__init__.py @@ -211,6 +211,7 @@ NODE_CONFIG = { "GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"}, "LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"}, "NABLA_AttentionKJ": {"class": NABLA_AttentionKJ, "name": "NABLA Attention KJ"}, + "TorchCompileModelAdvanced": {"class": TorchCompileModelAdvanced, "name": "TorchCompileModelAdvanced"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 7a0c905..2141e74 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -856,7 +856,60 @@ class TorchCompileModelWanVideoV2: raise RuntimeError("Failed to compile model") return (m, ) - + + +class TorchCompileModelAdvanced: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "backend": (["inductor","cudagraphs"], {"default": "inductor"}), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}), + "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), + "debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}), + }, + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "KJNodes/torchcompile" + DESCRIPTION = "Advanced torch.compile patching for diffusion models." + EXPERIMENTAL = True + + def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys): + from comfy_api.torch_helpers import set_torch_compile_wrapper + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit + + try: + if compile_transformer_blocks_only: + layer_types = ["double_blocks", "single_blocks", "layers", "transformer_blocks", "blocks"] + compile_key_list = [] + for layer_name in layer_types: + if hasattr(diffusion_model, layer_name): + blocks = getattr(diffusion_model, layer_name) + for i in range(len(blocks)): + compile_key_list.append(f"diffusion_model.{layer_name}.{i}") + if not compile_key_list: + logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead") + elif debug_compile_keys: + logging.info("TorchCompileModelAdvanced: Compile key list:") + for key in compile_key_list: + logging.info(f" - {key}") + if not compile_key_list: + compile_key_list =["diffusion_model"] + + set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph) + except: + raise RuntimeError("Failed to compile model") + + return (m, ) + + class TorchCompileModelQwenImage: @classmethod def INPUT_TYPES(s):