From d3287d61b7d94603481ce8109794c57786b3e7ca Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 6 Nov 2024 23:38:07 +0200
Subject: [PATCH] Allow setting dynamo_cache_size_limit

---
 mochi_preview/dit/joint_model/asymm_models_joint.py | 2 +-
 mochi_preview/dit/joint_model/layers.py             | 2 +-
 mochi_preview/t2v_synth_mochi.py                     | 3 +++
 nodes.py                                             | 4 +++-
 4 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index a05be43..b7259b0 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -165,6 +165,7 @@ class AsymmetricAttention(nn.Module):
                 raise ImportError("Flash RMSNorm not available.")
         elif rms_norm_func == "apex":
             from apex.normalization import FusedRMSNorm as ApexRMSNorm
+            @torch.compiler.disable()
             class RMSNorm(ApexRMSNorm):
                 pass
         else:
@@ -237,7 +238,6 @@ class AsymmetricAttention(nn.Module):
                 skip_reshape=True
             )
         return out
-
     def run_attention(
         self,
         q,
diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py
index 9d66921..f3c26a1 100644
--- a/mochi_preview/dit/joint_model/layers.py
+++ b/mochi_preview/dit/joint_model/layers.py
@@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x

-
+@torch.compiler.disable()
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None):
         super().__init__()
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 69ae024..c6ab363 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -34,6 +34,7 @@ except:

 import torch
 import torch.utils.data
+import torch._dynamo
 from tqdm import tqdm

 from comfy.utils import ProgressBar, load_torch_file
@@ -161,6 +162,8 @@ class T2VSynthMochiModel:

         #torch.compile
         if compile_args is not None:
+            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+            log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
             if compile_args["compile_dit"]:
                 for i, block in enumerate(model.blocks):
                     model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
diff --git a/nodes.py b/nodes.py
index 50ac380..89c72c1 100644
--- a/nodes.py
+++ b/nodes.py
@@ -247,6 +247,7 @@ class MochiTorchCompileSettings:
                 "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
                 "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
             },
         }
     RETURN_TYPES = ("MOCHICOMPILEARGS",)
@@ -255,7 +256,7 @@ class MochiTorchCompileSettings:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. 
 Requires Triton and torch 2.5.0 is recommended"

-    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
         compile_args = {
             "backend": backend,
@@ -264,6 +265,7 @@ class MochiTorchCompileSettings:
             "compile_dit": compile_dit,
             "compile_final_layer": compile_final_layer,
             "dynamic": dynamic,
+            "dynamo_cache_size_limit": dynamo_cache_size_limit,
         }
         return (compile_args, )
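For context, the patch leans on two standard PyTorch 2.x mechanisms: torch._dynamo.config.cache_size_limit, which caps how many recompiled variants dynamo keeps for a given piece of code before it gives up and runs it eagerly, and torch.compiler.disable(), which keeps the decorated RMSNorm classes out of the traced graph. The sketch below is illustrative only and not code from this repository; ToyRMSNorm and ToyBlock are made-up stand-ins, and the "eager" backend is used so the demo does not require Triton (the node itself defaults to "inductor"). It assumes a recent torch release (the patch recommends 2.5.0), where torch.compiler.disable() can be applied directly to an nn.Module subclass as the patch does.

import torch
import torch._dynamo

# What T2VSynthMochiModel now does when compile_args are provided: raise the
# recompile cache limit before compiling (64 matches the node's default).
torch._dynamo.config.cache_size_limit = 64


@torch.compiler.disable()  # excluded from dynamo tracing, always runs eagerly
class ToyRMSNorm(torch.nn.Module):
    """Illustrative stand-in for the RMSNorm classes decorated in the patch."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight


class ToyBlock(torch.nn.Module):
    """Illustrative stand-in for one transformer block from model.blocks."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.norm = ToyRMSNorm(hidden_size)
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(self.norm(x))


# Mirrors the per-block compile loop; "eager" avoids needing Triton for this
# demo, whereas the node defaults to the "inductor" backend.
block = torch.compile(ToyBlock(), backend="eager", fullgraph=False, dynamic=False)
print(block(torch.randn(4, 32)).shape)  # torch.Size([4, 32])

Exposing the cache limit as a node input matters because each new resolution or frame-count combination can force every compiled block to recompile; once the limit is exhausted, dynamo stops recompiling that code and falls back to eager execution, losing the benefit of compilation for the rest of the run.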