Allow compiling VAE as well

Jukka Seppänen 2024-11-23 17:08:57 +02:00
parent 9baf100366
commit 8d6e53b556


@@ -839,6 +839,7 @@ class CogVideoXVAELoader:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16"}
                 ),
+                "compile_args":("COMPILEARGS", ),
             }
         }
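Note: the COMPILEARGS input is produced by a separate compile-settings node that is not part of this commit. Judging only from the keys read out of compile_args later in this diff, the dict it passes in looks roughly like the sketch below; the concrete values are placeholders, only the key names come from the code.

compile_args = {
    "backend": "inductor",            # torch.compile backend
    "fullgraph": False,               # True disallows graph breaks
    "dynamic": False,                 # dynamic-shape tracing
    "mode": "default",                # e.g. "default" or "max-autotune"
    "dynamo_cache_size_limit": 64,    # torch._dynamo recompilation cache limit
}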
@@ -848,7 +849,7 @@ class CogVideoXVAELoader:
     CATEGORY = "CogVideoWrapper"
     DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
-    def loadmodel(self, model_name, precision):
+    def loadmodel(self, model_name, precision, compile_args=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -860,6 +861,10 @@ class CogVideoXVAELoader:
         vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
         vae.load_state_dict(vae_sd)
+        #compile
+        if compile_args is not None:
+            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+            vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
         return (vae,)
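A minimal, self-contained sketch of the new code path: a small stand-in nn.Module takes the place of AutoencoderKLCogVideoX, and the compile_args values are placeholders; only the dict keys and the torch.compile / torch._dynamo.config calls come from the diff above.

import torch
import torch.nn as nn

# Stand-in module; in the node this is the loaded AutoencoderKLCogVideoX VAE.
vae = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(8, 3, 3, padding=1),
)

# Placeholder values; the real dict comes from the COMPILEARGS input.
compile_args = {"backend": "inductor", "fullgraph": False, "dynamic": False,
                "mode": "default", "dynamo_cache_size_limit": 64}

if compile_args is not None:
    # cache_size_limit caps how many recompiled variants dynamo keeps per function.
    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
    vae = torch.compile(
        vae,
        fullgraph=compile_args["fullgraph"],
        dynamic=compile_args["dynamic"],
        backend=compile_args["backend"],
        mode=compile_args["mode"],
    )

# torch.compile is lazy: the first forward pass triggers the actual compilation.
with torch.no_grad():
    out = vae(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 64, 64])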