mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 09:53:34 +08:00
better compiler selection
This commit is contained in:
parent
cf01dc2b0b
commit
0c9478a9fb
14
nodes.py
14
nodes.py
@ -32,8 +32,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
||||||
),
|
),
|
||||||
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
|
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
|
||||||
"torch_compile": ("BOOLEAN", {"default": False, "tooltip": "use torch.compile to speed up inference, Linux only"}),
|
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||||
"onediff": ("BOOLEAN", {"default": False, "tooltip": "use onediff/nexfort to speed up inference, requires onediff installed (Linux only)"}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model, precision, fp8_transformer, torch_compile, onediff):
|
def loadmodel(self, model, precision, fp8_transformer, compile="disabled"):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
@ -74,13 +73,12 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
|
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||||
|
|
||||||
if torch_compile:
|
if compile == "torch":
|
||||||
torch._dynamo.config.suppress_errors = True
|
torch._dynamo.config.suppress_errors = True
|
||||||
pipe.transformer.to(memory_format=torch.channels_last)
|
pipe.transformer.to(memory_format=torch.channels_last)
|
||||||
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
|
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
|
||||||
|
elif compile == "onediff":
|
||||||
if onediff:
|
from onediffx import compile_pipe
|
||||||
from onediffx import compile_pipe, quantize_pipe
|
|
||||||
options = None
|
options = None
|
||||||
pipe = compile_pipe(
|
pipe = compile_pipe(
|
||||||
pipe,
|
pipe,
|
||||||
@ -95,7 +93,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"pipe": pipe,
|
"pipe": pipe,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"base_path": base_path,
|
"base_path": base_path,
|
||||||
"onediff": onediff
|
"onediff": True if compile == "onediff" else False
|
||||||
}
|
}
|
||||||
|
|
||||||
return (pipeline,)
|
return (pipeline,)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user