From 25f16462aa6a4b563a24e6eeb4ab37a026de86ac Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 22:50:13 +0300
Subject: [PATCH] torch compile maybe

---
 nodes.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/nodes.py b/nodes.py
index f86c047..d279033 100644
--- a/nodes.py
+++ b/nodes.py
@@ -408,7 +408,7 @@ class DownloadAndLoadCogVideoModel:
         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@@ -458,6 +458,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "optional": {
                 "pab_config": ("PAB_CONFIG", {"default": None}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
+                "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+
             }
         }

@@ -466,7 +468,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None, compile="disabled"):

         check_diffusers_version()

@@ -556,7 +558,11 @@ class DownloadAndLoadCogVideoGGUFModel:
         from .fp8_optimization import convert_fp8_linear
         convert_fp8_linear(transformer, vae_dtype)
-
+        # compilation
+        if compile == "torch":
+            torch._dynamo.config.suppress_errors = True
+            pipe.transformer.to(memory_format=torch.channels_last)
+            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")

         with open(scheduler_path) as f:
             scheduler_config = json.load(f)