dynamic compile
This commit is contained in:
parent
04d15b64ae
commit
f6020f71e0
@ -182,9 +182,9 @@ class T2VSynthMochiModel:
|
|||||||
if compile_args is not None:
|
if compile_args is not None:
|
||||||
if compile_args["compile_dit"]:
|
if compile_args["compile_dit"]:
|
||||||
for i, block in enumerate(model.blocks):
|
for i, block in enumerate(model.blocks):
|
||||||
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
|
model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
|
||||||
if compile_args["compile_final_layer"]:
|
if compile_args["compile_final_layer"]:
|
||||||
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
|
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
|
||||||
|
|
||||||
self.dit = model
|
self.dit = model
|
||||||
|
|
||||||
|
|||||||
4
nodes.py
4
nodes.py
@ -244,6 +244,7 @@ class MochiTorchCompileSettings:
|
|||||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||||
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
|
"compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
|
||||||
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
|
"compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
|
||||||
|
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("MOCHICOMPILEARGS",)
|
RETURN_TYPES = ("MOCHICOMPILEARGS",)
|
||||||
@ -252,7 +253,7 @@ class MochiTorchCompileSettings:
|
|||||||
CATEGORY = "MochiWrapper"
|
CATEGORY = "MochiWrapper"
|
||||||
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
||||||
|
|
||||||
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer):
|
def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):
|
||||||
|
|
||||||
compile_args = {
|
compile_args = {
|
||||||
"backend": backend,
|
"backend": backend,
|
||||||
@ -260,6 +261,7 @@ class MochiTorchCompileSettings:
|
|||||||
"mode": mode,
|
"mode": mode,
|
||||||
"compile_dit": compile_dit,
|
"compile_dit": compile_dit,
|
||||||
"compile_final_layer": compile_final_layer,
|
"compile_final_layer": compile_final_layer,
|
||||||
|
"dynamic": dynamic,
|
||||||
}
|
}
|
||||||
|
|
||||||
return (compile_args, )
|
return (compile_args, )
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user