fun pose GGUF

This commit is contained in:
kijai 2024-10-01 00:07:32 +03:00
parent f3a1ff933e
commit 03f237f925

View File

@ -384,7 +384,8 @@ class DownloadAndLoadCogVideoGGUFModel:
"CogVideoX_5b_GGUF_Q4_0.safetensors",
"CogVideoX_5b_I2V_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors"
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
],
),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
@ -451,7 +452,10 @@ class DownloadAndLoadCogVideoGGUFModel:
with mz_gguf_loader.quantize_lazy_load():
if "fun" in model:
transformer_config["in_channels"] = 33
if "Pose" in model:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
else:
@ -519,7 +523,10 @@ class DownloadAndLoadCogVideoGGUFModel:
if "fun" in model:
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
else:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd)