From 4d0e5cf2400595ad3f2d93c26126b8164fce1f39 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 12 Jan 2025 18:46:35 +0200 Subject: [PATCH] torch.compile for cosmos --- __init__.py | 1 + nodes/model_optimization_nodes.py | 47 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/__init__.py b/__init__.py index f024196..b451ec2 100644 --- a/__init__.py +++ b/__init__.py @@ -162,6 +162,7 @@ NODE_CONFIG = { "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"}, "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"}, "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"}, + "TorchCompileCosmosModel": {"class": TorchCompileCosmosModel, "name": "TorchCompileCosmosModel"}, "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Pathch Sage Attention KJ"}, #instance diffusion diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 0f39db1..5213c27 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -461,4 +461,51 @@ class TorchCompileLTXModel: except: raise RuntimeError("Failed to compile model") + return (m, ) + +class TorchCompileCosmosModel: + def __init__(self): + self._compiled = False + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL",), + "backend": (["inductor", "cudagraphs"],), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "KJNodes/experimental" + EXPERIMENTAL = True + + def patch(self, model, backend, mode, fullgraph, dynamic): + + m = model.clone() + diffusion_model = m.get_model_object("diffusion_model") + + if not self._compiled: + try: + for name, block in diffusion_model.blocks.items(): + print(f"Compiling block {name}") + compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend) + m.add_object_patch(f"diffusion_model.blocks.{name}", compiled_block) + #diffusion_model.blocks[name] = compiled_block + + self._compiled = True + compile_settings = { + "backend": backend, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + setattr(m.model, "compile_settings", compile_settings) + print(model.model.diffusion_model.blocks) + + except: + raise RuntimeError("Failed to compile model") + return (m, ) \ No newline at end of file