diff --git a/model_loading.py b/model_loading.py
index 1123a4f..d8ad72f 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -147,6 +147,7 @@ class DownloadAndLoadCogVideoModel:
                     "alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
                     "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
                     "alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
+                    "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP",
                     "feizhengcong/CogvideoX-Interpolation",
                     "NimVideo/cogvideox-2b-img2vid"
                 ],
@@ -215,7 +216,7 @@ class DownloadAndLoadCogVideoModel:
         download_path = folder_paths.get_folder_paths("CogVideo")[0]

         if "Fun" in model:
-            if not "1.1" in model:
+            if not "1.1" in model and not "1.5" in model:
                 repo_id = "kijai/CogVideoX-Fun-pruned"
                 if "2b" in model:
                     base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
@@ -225,7 +226,7 @@ class DownloadAndLoadCogVideoModel:
                     base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
                     if not os.path.exists(base_path):
                         base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
-            elif "1.1" in model:
+            else:
                 repo_id = model
                 base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1])) # location of the official model
                 if not os.path.exists(base_path):
@@ -278,7 +279,7 @@ class DownloadAndLoadCogVideoModel:
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
         transformer = transformer.to(dtype).to(transformer_load_device)

-        if "1.5" in model:
+        if "1.5" in model and not "fun" in model.lower():
             transformer.config.sample_height = 300
             transformer.config.sample_width = 300

diff --git a/nodes.py b/nodes.py
index 3819b54..08997b0 100644
--- a/nodes.py
+++ b/nodes.py
@@ -360,8 +360,8 @@ class CogVideoImageEncodeFunInP:
         masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W

         mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
-        if end_image is not None:
-            mask[:, -1, :, :, :] = 0
+        #if end_image is not None:
+        #    mask[:, -1, :, :, :] = 0
         mask[:, 0, :, :, :] = vae_scaling_factor

         final_latents = masked_image_latents * vae_scaling_factor
@@ -623,7 +623,7 @@ class CogVideoSampler:
             image_conds = image_cond_latents["samples"]
             image_cond_start_percent = image_cond_latents.get("start_percent", 0.0)
             image_cond_end_percent = image_cond_latents.get("end_percent", 1.0)
-            if "1.5" in model_name or "1_5" in model_name:
+            if ("1.5" in model_name or "1_5" in model_name) and not "fun" in model_name.lower():
                 image_conds = image_conds / 0.7 # needed for 1.5 models
             else:
                 if not "fun" in model_name.lower():
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 8b3a71d..0db9671 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -471,6 +471,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         # 5.5.
         if image_cond_latents is not None:
+            image_cond_frame_count = image_cond_latents.size(1)
+            patch_size_t = self.transformer.config.patch_size_t
             if image_cond_latents.shape[1] == 2:
                 logger.info("More than one image conditioning frame received, interpolating")
                 padding_shape = (
@@ -482,8 +484,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 )
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
                 image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                if self.transformer.config.patch_size_t is not None:
-                    first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
+                if patch_size_t:
+                    first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
                     image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
                 logger.info(f"image cond latents shape: {image_cond_latents.shape}")
@@ -500,13 +502,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
                     image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
                     # Select the first frame along the second dimension
-                    if self.transformer.config.patch_size_t is not None:
-                        first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
+                    if patch_size_t:
+                        first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
                         image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
                 else:
                     image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
             else:
                 logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
+                if fun_mask is not None and patch_size_t:
+                    logger.info(f"1.5 model received {fun_mask.shape[1]} masks")
+                    first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
+                    image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
+                    fun_mask_first_frame = fun_mask[:, : image_cond_frame_count % patch_size_t, ...]
+                    fun_mask = torch.cat([fun_mask_first_frame, fun_mask], dim=1)

             image_cond_latents = image_cond_latents.to(self.vae_dtype)

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
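
Note on the patch_size_t handling in the pipeline hunk: the 1.5 transformers patchify the latent sequence along the temporal axis (patch_size_t is 2 for them), so the number of conditioning latent frames has to line up with patch_size_t; the new fun_mask branch applies the same leading-frame padding to the Fun masks so they stay aligned with the padded conditioning latents. Below is a minimal, self-contained sketch of that padding step, not the wrapper's code; the helper name and the (B, T, C, H, W) latent layout are assumptions for illustration.

import torch

def pad_to_temporal_patch(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
    # Prepend the first (T % patch_size_t) frames, mirroring the diff's
    # `image_cond_frame_count % patch_size_t` slice-and-cat; for
    # patch_size_t = 2 this always makes the frame count divisible by 2.
    frame_count = latents.size(1)
    remainder = frame_count % patch_size_t
    if remainder:
        latents = torch.cat([latents[:, :remainder], latents], dim=1)
    return latents

cond = torch.zeros(1, 13, 16, 60, 90)          # B, T, C, H, W with an odd frame count
padded = pad_to_temporal_patch(cond, 2)
print(padded.shape)                              # torch.Size([1, 14, 16, 60, 90])

The same function applied to the mask tensor keeps its temporal length equal to the conditioning latents', which is why the diff repeats the slice for fun_mask rather than padding only the latents.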