mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
Add cogvideox-2b-img2vid CogVideoXModelLoader support
fix for transformer model patch_embed.pos_embedding dtype or at add line ComfyUI-CogVideoXWrapper/embeddings.py:129 code pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
This commit is contained in:
parent
729a6485ea
commit
d9d30f24bb
@ -727,6 +727,8 @@ class CogVideoXModelLoader:
|
||||
model_type = "5b_I2V_1_5"
|
||||
elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
|
||||
model_type = "fun_2b"
|
||||
elif sd["patch_embed.proj.weight"].shape == (1920, 32, 2, 2):
|
||||
model_type = "cogvideox-2b-img2vid"
|
||||
elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
|
||||
model_type = "2b"
|
||||
elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
|
||||
@ -748,7 +750,7 @@ class CogVideoXModelLoader:
|
||||
with open(transformer_config_path) as f:
|
||||
transformer_config = json.load(f)
|
||||
|
||||
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
|
||||
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5", "cogvideox-2b-img2vid"]:
|
||||
transformer_config["in_channels"] = 32
|
||||
if "1_5" in model_type:
|
||||
transformer_config["ofs_embed_dim"] = 512
|
||||
@ -774,6 +776,10 @@ class CogVideoXModelLoader:
|
||||
#dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
|
||||
set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
|
||||
del sd
|
||||
# TODO fix for transformer model patch_embed.pos_embedding dtype
|
||||
# or at add line ComfyUI-CogVideoXWrapper/embeddings.py:129 code
|
||||
# pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
|
||||
transformer = transformer.to(base_dtype).to(transformer_load_device)
|
||||
|
||||
#scheduler
|
||||
with open(scheduler_config_path) as f:
|
||||
@ -797,7 +803,8 @@ class CogVideoXModelLoader:
|
||||
dtype=base_dtype,
|
||||
is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
|
||||
)
|
||||
|
||||
if "cogvideox-2b-img2vid" == model_type:
|
||||
pipe.input_with_padding = False
|
||||
if enable_sequential_cpu_offload:
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user