better compiler selection

This commit is contained in:
Kijai 2024-08-28 17:15:47 +03:00
parent cf01dc2b0b
commit 0c9478a9fb

View File

@@ -32,8 +32,7 @@ class DownloadAndLoadCogVideoModel:
                 {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
             ),
             "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
-            "torch_compile": ("BOOLEAN", {"default": False, "tooltip": "use torch.compile to speed up inference, Linux only"}),
-            "onediff": ("BOOLEAN", {"default": False, "tooltip": "use onediff/nexfort to speed up inference, requires onediff installed (Linux only)"}),
+            "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
             }
         }
@@ -42,7 +41,7 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision, fp8_transformer, torch_compile, onediff):
+    def loadmodel(self, model, precision, fp8_transformer, compile="disabled"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
@@ -74,13 +73,12 @@ class DownloadAndLoadCogVideoModel:
         pipe = CogVideoXPipeline(vae, transformer, scheduler)

-        if torch_compile:
+        if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
             pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-
-        if onediff:
-            from onediffx import compile_pipe, quantize_pipe
+        elif compile == "onediff":
+            from onediffx import compile_pipe
             options = None
             pipe = compile_pipe(
                 pipe,
@@ -95,7 +93,7 @@ class DownloadAndLoadCogVideoModel:
             "pipe": pipe,
             "dtype": dtype,
             "base_path": base_path,
-            "onediff": onediff
+            "onediff": True if compile == "onediff" else False
         }
         return (pipeline,)