torch.compile option

experimental
kijai 2024-08-27 23:34:46 +03:00
parent 1e356fa905
commit 8be401f0bb


@@ -29,6 +29,7 @@ class DownloadAndLoadCogVideoModel:
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
                 "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+                "torch_compile": ("BOOLEAN", {"default": False, "tooltip": "use torch.compile to speed up inference, Linux only"}),
             }
         }
@@ -37,7 +38,7 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
-    def loadmodel(self, model, precision, fp8_transformer):
+    def loadmodel(self, model, precision, fp8_transformer, torch_compile):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
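
For context on why the signature changes with the new input: ComfyUI passes each key declared under "required" in INPUT_TYPES as a keyword argument to the method named by FUNCTION, so adding the torch_compile input requires a matching parameter on loadmodel. A minimal illustrative sketch of that wiring (the class, category, and return values here are made up for the example):

import torch

class ExampleNode:
    @classmethod
    def INPUT_TYPES(s):
        # Each key in "required" becomes a keyword argument to the
        # method named by FUNCTION below.
        return {"required": {"torch_compile": ("BOOLEAN", {"default": False})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "Example"

    def run(self, torch_compile):
        # The boolean toggle from the node UI arrives here directly.
        return ("compiled" if torch_compile else "eager",)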
@@ -69,6 +70,10 @@ class DownloadAndLoadCogVideoModel:
         pipe = CogVideoXPipeline(vae, transformer, scheduler)
+        if torch_compile:
+            pipe.transformer.to(memory_format=torch.channels_last)
+            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
         pipeline = {
             "pipe": pipe,
             "dtype": dtype,