Merge pull request #311 from glide-the/fix_pos_embedding

Add cogvideox-2b-img2vid CogVideoXModelLoader support
This commit is contained in:
Jukka Seppänen 2024-12-08 11:52:12 +02:00 committed by GitHub
commit 795f8b0565
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -727,6 +727,8 @@ class CogVideoXModelLoader:
model_type = "5b_I2V_1_5"
elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
model_type = "fun_2b"
elif sd["patch_embed.proj.weight"].shape == (1920, 32, 2, 2):
model_type = "cogvideox-2b-img2vid"
elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
model_type = "2b"
elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
@ -748,7 +750,7 @@ class CogVideoXModelLoader:
with open(transformer_config_path) as f:
transformer_config = json.load(f)
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5", "cogvideox-2b-img2vid"]:
transformer_config["in_channels"] = 32
if "1_5" in model_type:
transformer_config["ofs_embed_dim"] = 512
@ -774,6 +776,10 @@ class CogVideoXModelLoader:
#dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
del sd
# TODO fix for transformer model patch_embed.pos_embedding dtype
# or at add line ComfyUI-CogVideoXWrapper/embeddings.py:129 code
# pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
transformer = transformer.to(base_dtype).to(transformer_load_device)
#scheduler
with open(scheduler_config_path) as f:
@ -797,7 +803,8 @@ class CogVideoXModelLoader:
dtype=base_dtype,
is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
)
if "cogvideox-2b-img2vid" == model_type:
pipe.input_with_padding = False
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()