Allow compiling VAE as well

This commit is contained in:
Jukka Seppänen 2024-11-23 17:08:57 +02:00
parent 9baf100366
commit 8d6e53b556

View File

@ -839,6 +839,7 @@ class CogVideoXVAELoader:
"precision": (["fp16", "fp32", "bf16"], "precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"} {"default": "bf16"}
), ),
"compile_args":("COMPILEARGS", ),
} }
} }
@ -848,7 +849,7 @@ class CogVideoXVAELoader:
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'" DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
def loadmodel(self, model_name, precision): def loadmodel(self, model_name, precision, compile_args=None):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
@ -860,6 +861,10 @@ class CogVideoXVAELoader:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device) vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
vae.load_state_dict(vae_sd) vae.load_state_dict(vae_sd)
#compile
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
return (vae,) return (vae,)