dynamic compile

This commit is contained in:
kijai 2024-11-05 08:57:48 +02:00
parent 04d15b64ae
commit f6020f71e0
2 changed files with 5 additions and 3 deletions

View File

@ -182,9 +182,9 @@ class T2VSynthMochiModel:
if compile_args is not None: if compile_args is not None:
if compile_args["compile_dit"]: if compile_args["compile_dit"]:
for i, block in enumerate(model.blocks): for i, block in enumerate(model.blocks):
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
if compile_args["compile_final_layer"]: if compile_args["compile_final_layer"]:
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
self.dit = model self.dit = model

View File

@ -244,6 +244,7 @@ class MochiTorchCompileSettings:
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}), "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}), "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
}, },
} }
RETURN_TYPES = ("MOCHICOMPILEARGS",) RETURN_TYPES = ("MOCHICOMPILEARGS",)
@ -252,7 +253,7 @@ class MochiTorchCompileSettings:
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer): def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):
compile_args = { compile_args = {
"backend": backend, "backend": backend,
@ -260,6 +261,7 @@ class MochiTorchCompileSettings:
"mode": mode, "mode": mode,
"compile_dit": compile_dit, "compile_dit": compile_dit,
"compile_final_layer": compile_final_layer, "compile_final_layer": compile_final_layer,
"dynamic": dynamic,
} }
return (compile_args, ) return (compile_args, )