diff --git a/nodes.py b/nodes.py index 74360fd..b91cecd 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,7 @@ class DownloadAndLoadCogVideoModel: {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} ), "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}), + "torch_compile": ("BOOLEAN", {"default": False, "tooltip": "use torch.compile to speed up inference, Linux only"}), } } @@ -37,7 +38,7 @@ class DownloadAndLoadCogVideoModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, model, precision, fp8_transformer): + def loadmodel(self, model, precision, fp8_transformer, torch_compile): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -69,6 +70,10 @@ class DownloadAndLoadCogVideoModel: pipe = CogVideoXPipeline(vae, transformer, scheduler) + if torch_compile: + pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) + pipeline = { "pipe": pipe, "dtype": dtype,