fun pose GGUF

This commit is contained in:
kijai 2024-10-01 00:07:32 +03:00
parent f3a1ff933e
commit 03f237f925

View File

@ -384,7 +384,8 @@ class DownloadAndLoadCogVideoGGUFModel:
"CogVideoX_5b_GGUF_Q4_0.safetensors", "CogVideoX_5b_GGUF_Q4_0.safetensors",
"CogVideoX_5b_I2V_GGUF_Q4_0.safetensors", "CogVideoX_5b_I2V_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_GGUF_Q4_0.safetensors", "CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors" "CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
], ],
), ),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
@ -451,7 +452,10 @@ class DownloadAndLoadCogVideoGGUFModel:
with mz_gguf_loader.quantize_lazy_load(): with mz_gguf_loader.quantize_lazy_load():
if "fun" in model: if "fun" in model:
transformer_config["in_channels"] = 33 if "Pose" in model:
transformer_config["in_channels"] = 32
else:
transformer_config["in_channels"] = 33
if pab_config is not None: if pab_config is not None:
transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config) transformer = CogVideoXTransformer3DModelFunPAB.from_config(transformer_config)
else: else:
@ -519,7 +523,10 @@ class DownloadAndLoadCogVideoGGUFModel:
if "fun" in model: if "fun" in model:
vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device) vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd) vae.load_state_dict(vae_sd)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config) if "Pose" in model:
pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config)
else:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
else: else:
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device) vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
vae.load_state_dict(vae_sd) vae.load_state_dict(vae_sd)