From ca63f5dadea550a353ca99c36f8f85c33b876985 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 11 Nov 2024 01:19:11 +0200 Subject: [PATCH] update --- custom_cogvideox_transformer_3d.py | 3 +-- model_loading.py | 3 ++- nodes.py | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index ed955a4..de80c2d 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -571,11 +571,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # 2. Patch embedding p = self.config.patch_size p_t = self.config.patch_size_t - # We know that the hidden states height and width will always be divisible by patch_size. # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames. if p_t is not None: - remaining_frames = p_t - num_frames % p_t + remaining_frames = 0 if num_frames % 2 == 0 else 1 first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1) hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1) diff --git a/model_loading.py b/model_loading.py index acf8c9e..89c1516 100644 --- a/model_loading.py +++ b/model_loading.py @@ -263,7 +263,8 @@ class DownloadAndLoadCogVideoModel: pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) if fuse: lora_scale = 1 - if "dimensionx" in lora[-1]["path"].lower(): + dimension_loras = ["orbit_left_lora", "dimensionx"] # for now dimensionx loras need scaling + if any(item in lora[-1]["path"].lower() for item in dimension_loras): lora_scale = lora_scale / lora_rank pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"]) diff --git a/nodes.py b/nodes.py index 2ac987a..248306a 100644 --- a/nodes.py +++ b/nodes.py @@ -828,6 +828,10 @@ class CogVideoSampler: num_frames == 49 or context_options is not None ), "1.0 I2V model can only do 49 frames" + if image_cond_latents is not None: + assert "I2V" in pipeline.get("model_name", ""), "Image condition latents only supported for I2V models" + else: + assert "I2V" not in pipeline.get("model_name", ""), "Image condition latents required for I2V models" device = mm.get_torch_device() offload_device = mm.unet_offload_device()