diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index ecc152b..9b0d405 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -182,9 +182,9 @@ class T2VSynthMochiModel: if compile_args is not None: if compile_args["compile_dit"]: for i, block in enumerate(model.blocks): - model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) + model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"]) if compile_args["compile_final_layer"]: - model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) + model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"]) self.dit = model diff --git a/nodes.py b/nodes.py index 83dee79..c36c720 100644 --- a/nodes.py +++ b/nodes.py @@ -244,6 +244,7 @@ class MochiTorchCompileSettings: "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}), "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), }, } RETURN_TYPES = ("MOCHICOMPILEARGS",) @@ -252,7 +253,7 @@ class MochiTorchCompileSettings: CATEGORY = "MochiWrapper" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" - def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer): + def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic): compile_args = { "backend": backend, @@ -260,6 +261,7 @@ class MochiTorchCompileSettings: "mode": mode, "compile_dit": compile_dit, "compile_final_layer": compile_final_layer, + "dynamic": dynamic, } return (compile_args, )