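Tentatively enable torch.compile: use default mode without fullgraph (inductor backend) and add a compile option to the GGUF loader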

This commit is contained in:
kijai 2024-10-25 22:50:13 +03:00
parent 249e8d54d1
commit 25f16462aa

View File

@ -408,7 +408,7 @@ class DownloadAndLoadCogVideoModel:
if compile == "torch":
torch._dynamo.config.suppress_errors = True
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
elif compile == "onediff":
from onediffx import compile_pipe
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@ -458,6 +458,8 @@ class DownloadAndLoadCogVideoGGUFModel:
"optional": {
"pab_config": ("PAB_CONFIG", {"default": None}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
}
}
@ -466,7 +468,7 @@ class DownloadAndLoadCogVideoGGUFModel:
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None, compile="disabled"):
check_diffusers_version()
@ -556,7 +558,11 @@ class DownloadAndLoadCogVideoGGUFModel:
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, vae_dtype)
# compilation
if compile == "torch":
torch._dynamo.config.suppress_errors = True
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
with open(scheduler_path) as f:
scheduler_config = json.load(f)