From fb5aa296aea6cbcadd04d65592fbd4a7df79849a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:00:49 +0200 Subject: [PATCH] Add TorchCompileLTXModel --- __init__.py | 1 + nodes/nodes.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/__init__.py b/__init__.py index ba6ffd8..64d8cb4 100644 --- a/__init__.py +++ b/__init__.py @@ -158,6 +158,7 @@ NODE_CONFIG = { "TorchCompileVAE": {"class": TorchCompileVAE, "name": "TorchCompileVAE"}, "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"}, "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"}, + "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 08698e7..44dc12e 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2180,6 +2180,7 @@ class CheckpointLoaderKJ: if sage_attention: from sageattention import sageattn + @torch.compiler.disable() def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape @@ -2483,6 +2484,47 @@ class TorchCompileControlNet: return (controlnet, ) +class TorchCompileLTXModel: + def __init__(self): + self._compiled = False + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL",), + "backend": (["inductor", "cudagraphs"],), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "KJNodes/experimental" + EXPERIMENTAL = True + + def patch(self, model, backend, mode, fullgraph, dynamic): + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + + if not self._compiled: + try: + for i, block in enumerate(diffusion_model.transformer_blocks): + #print("Compiling double_block", i) + m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)) + self._compiled = True + compile_settings = { + "backend": backend, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + setattr(m.model, "compile_settings", compile_settings) + except: + raise RuntimeError("Failed to compile model") + + return (m, ) + class StyleModelApplyAdvanced: @classmethod def INPUT_TYPES(s):