From 0c9478a9fb8660cfe746ebf4b07a2ada68c02aed Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:15:47 +0300 Subject: [PATCH] better compiler selection --- nodes.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/nodes.py b/nodes.py index 33dbd42..8c21e49 100644 --- a/nodes.py +++ b/nodes.py @@ -32,8 +32,7 @@ class DownloadAndLoadCogVideoModel: {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} ), "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}), - "torch_compile": ("BOOLEAN", {"default": False, "tooltip": "use torch.compile to speed up inference, Linux only"}), - "onediff": ("BOOLEAN", {"default": False, "tooltip": "use onediff/nexfort to speed up inference, requires onediff installed (Linux only)"}), + "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), } } @@ -42,7 +41,7 @@ class DownloadAndLoadCogVideoModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, precision, fp8_transformer, torch_compile, onediff): + def loadmodel(self, model, precision, fp8_transformer, compile="disabled"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -74,13 +73,12 @@ class DownloadAndLoadCogVideoModel: pipe = CogVideoXPipeline(vae, transformer, scheduler) - if torch_compile: + if compile == "torch": torch._dynamo.config.suppress_errors = True pipe.transformer.to(memory_format=torch.channels_last) pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) - - if onediff: - from onediffx import compile_pipe, quantize_pipe + elif compile == "onediff": + from onediffx import compile_pipe options = None pipe = compile_pipe( pipe, @@ -95,7 +93,7 @@ class DownloadAndLoadCogVideoModel: "pipe": pipe, "dtype": dtype, "base_path": base_path, - "onediff": onediff + "onediff": True if compile == "onediff" else False } return (pipeline,) @@ -270,7 +268,7 @@ class CogVideoSampler: pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler") elif scheduler == "DPM": pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler") - + autocastcondition = not pipeline["onediff"] autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: