From 16f60e53e513aec0113bc8dc0e78b035f431d88e Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 21 May 2025 17:17:41 +0300
Subject: [PATCH] Add TorchCompileModelFluxAdvancedV2

Utilizing new ComfyUI core functions that also support LoRAs without
additional patches
---
 __init__.py                       |  1 +
 nodes/model_optimization_nodes.py | 60 ++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/__init__.py b/__init__.py
index 8785bac..60128a8 100644
--- a/__init__.py
+++ b/__init__.py
@@ -174,6 +174,7 @@ NODE_CONFIG = {
     "CheckpointLoaderKJ": {"class": CheckpointLoaderKJ, "name": "CheckpointLoaderKJ"},
     "DiffusionModelLoaderKJ": {"class": DiffusionModelLoaderKJ, "name": "Diffusion Model Loader KJ"},
     "TorchCompileModelFluxAdvanced": {"class": TorchCompileModelFluxAdvanced, "name": "TorchCompileModelFluxAdvanced"},
+    "TorchCompileModelFluxAdvancedV2": {"class": TorchCompileModelFluxAdvancedV2, "name": "TorchCompileModelFluxAdvancedV2"},
     "TorchCompileModelHyVideo": {"class": TorchCompileModelHyVideo, "name": "TorchCompileModelHyVideo"},
     "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"},
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index a49a707..0506335 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -22,6 +22,7 @@ class BaseLoaderKJ:
     original_linear = None
     cublas_patched = False
 
+    @torch.compiler.disable()
     def _patch_modules(self, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
 
@@ -146,8 +147,10 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     def patch(self, model, sage_attention):
         model_clone = model.clone()
 
+        @torch.compiler.disable()
         def patch_attention_enable(model):
             self._patch_modules(False, sage_attention)
 
+        @torch.compiler.disable()
         def patch_attention_disable(model):
             self._patch_modules(False, "disabled")
@@ -574,6 +577,63 @@ class TorchCompileModelFluxAdvanced:
         # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
         # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
 
+class TorchCompileModelFluxAdvancedV2:
+    def __init__(self):
+        self._compiled = False
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "backend": (["inductor", "cudagraphs"],),
+            "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
+            "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
+            "double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
+            "single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
+            "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+            },
+            "optional": {
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+            }
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "KJNodes/torchcompile"
+    EXPERIMENTAL = True
+
+    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit=64):
+        """Compile Flux double/single blocks via ComfyUI core set_torch_compile_wrapper (keeps LoRA support)."""
+        from comfy_api.torch_helpers import set_torch_compile_wrapper
+        m = model.clone()
+        diffusion_model = m.get_model_object("diffusion_model")
+        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
+
+        # keys are attribute-path strings resolved by get_model_object, not the module objects themselves
+        compile_key_list = []
+
+        try:
+            if double_blocks:
+                for i, _ in enumerate(diffusion_model.double_blocks):
+                    compile_key_list.append(f"diffusion_model.double_blocks.{i}")
+            if single_blocks:
+                for i, _ in enumerate(diffusion_model.single_blocks):
+                    compile_key_list.append(f"diffusion_model.single_blocks.{i}")
+
+            set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
+        except Exception as e:
+            raise RuntimeError("Failed to compile model") from e
+
+        return (m, )
+        # rest of the layers that are not patched
+        # diffusion_model.final_layer = torch.compile(diffusion_model.final_layer, mode=mode, fullgraph=fullgraph, backend=backend)
+        # diffusion_model.guidance_in = torch.compile(diffusion_model.guidance_in, mode=mode, fullgraph=fullgraph, backend=backend)
+        # diffusion_model.img_in = torch.compile(diffusion_model.img_in, mode=mode, fullgraph=fullgraph, backend=backend)
+        # diffusion_model.time_in = torch.compile(diffusion_model.time_in, mode=mode, fullgraph=fullgraph, backend=backend)
+        # diffusion_model.txt_in = torch.compile(diffusion_model.txt_in, mode=mode, fullgraph=fullgraph, backend=backend)
+        # diffusion_model.vector_in = torch.compile(diffusion_model.vector_in, mode=mode, fullgraph=fullgraph, backend=backend)
+
+
 class TorchCompileModelHyVideo:
     def __init__(self):
         self._compiled = False