Add compile_args node

This commit is contained in:
kijai 2024-11-07 23:17:37 +02:00
parent 9cce8015d3
commit 07defb52b6
2 changed files with 50 additions and 11 deletions

View File

@ -59,6 +59,7 @@ class DownloadAndLoadCogVideoModel:
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
}
}
@ -68,7 +69,7 @@ class DownloadAndLoadCogVideoModel:
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None):
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None):
check_diffusers_version()
@ -186,17 +187,23 @@ class DownloadAndLoadCogVideoModel:
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
# compilation
if compile == "torch":
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 64
pipe.transformer.to(memory_format=torch.channels_last)
#pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
else:
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
elif compile == "onediff":
from onediffx import compile_pipe
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'

View File

@ -182,6 +182,36 @@ class CogVideoLoraSelect:
print(cog_loras_list)
return (cog_loras_list,)
class CogVideoXTorchCompileSettings:
    """Node that packages torch.compile options into one COMPILEARGS dict.

    The node does not compile anything itself. It only collects the widget
    values below into a plain dictionary that is emitted on its single
    output, so that whichever node receives it can read the settings by key.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: five required settings widgets."""
        return {
            "required": {
                "backend": (["inductor","cudagraphs"], {"default": "inductor"}),
                "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
                "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
            },
        }

    RETURN_TYPES = ("COMPILEARGS",)
    RETURN_NAMES = ("torch_compile_args",)
    # Matches the entry-point method defined below. The name is kept for
    # backward compatibility even though nothing is loaded here.
    FUNCTION = "loadmodel"
    # Use the same category string as the CogVideo model loader node; the
    # previous value "MochiWrapper" was a leftover from another wrapper.
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"

    def loadmodel(self, backend: str, fullgraph: bool, mode: str, dynamic: bool, dynamo_cache_size_limit: int) -> tuple:
        """Bundle the selected settings into a dict and return it as a 1-tuple.

        Args:
            backend: Selected backend name (one of the choices in INPUT_TYPES).
            fullgraph: Value of the "fullgraph" toggle.
            mode: Selected mode name (one of the choices in INPUT_TYPES).
            dynamic: Value of the "dynamic" toggle.
            dynamo_cache_size_limit: Integer intended for
                torch._dynamo.config.cache_size_limit (per the widget tooltip).

        Returns:
            A one-element tuple holding a dict with the keys "backend",
            "fullgraph", "mode", "dynamic" and "dynamo_cache_size_limit".
            Values are passed through unchanged; no validation is done here.
        """
        compile_args = {
            "backend": backend,
            "fullgraph": fullgraph,
            "mode": mode,
            "dynamic": dynamic,
            "dynamo_cache_size_limit": dynamo_cache_size_limit,
        }
        return (compile_args, )
#region TextEncode
class CogVideoEncodePrompt:
@classmethod
@ -1376,7 +1406,8 @@ NODE_CLASS_MAPPINGS = {
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
"CogVideoXFasterCache": CogVideoXFasterCache,
"CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket,
"CogVideoLatentPreview": CogVideoLatentPreview
"CogVideoLatentPreview": CogVideoLatentPreview,
"CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoSampler": "CogVideo Sampler",
@ -1398,5 +1429,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
"CogVideoXFasterCache": "CogVideoX FasterCache",
"CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket",
"CogVideoLatentPreview": "CogVideo LatentPreview"
"CogVideoLatentPreview": "CogVideo LatentPreview",
"CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
}