diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index de80c2d..2fa191a 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -573,10 +573,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         p_t = self.config.patch_size_t
         # We know that the hidden states height and width will always be divisible by patch_size.
         # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
-        if p_t is not None:
-            remaining_frames = 0 if num_frames % 2 == 0 else 1
-            first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
-            hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
+        # if p_t is not None:
+        #     remaining_frames = 0 if num_frames % 2 == 0 else 1
+        #     first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
+        #     hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
 
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
@@ -711,7 +711,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
             )
             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-            output = output[:, remaining_frames:]
+            #output = output[:, remaining_frames:]
 
         if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
diff --git a/nodes.py b/nodes.py
index 248306a..42af0ad 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1004,7 +1004,7 @@ class CogVideoDecode:
            vae._clear_fake_context_parallel_cache()
        except:
            pass
-        frames = vae.decode(latents).sample
+        frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
        vae.disable_tiling()
        if not pipeline["cpu_offloading"]:
            vae.to(offload_device)
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index ae1aa6d..007987e 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -434,6 +434,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
+        self.num_frames = num_frames
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             height,
@@ -463,6 +465,14 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
 
         # 5. Prepare latents.
         latent_channels = self.vae.config.latent_channels
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        self.additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            self.additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += self.additional_frames * self.vae_scale_factor_temporal
+
         #if latents is None and num_frames == t_tile_length:
         #    num_frames += 1
@@ -503,8 +513,12 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                     width // self.vae_scale_factor_spatial,
                 )
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
-                image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
+                # Select the first frame along the second dimension
+                if self.transformer.config.patch_size_t is not None:
+                    first_frame = image_cond_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+                    image_cond_latents = torch.cat([first_frame, image_latents], dim=1)
+                logger.info(f"image cond latents shape: {image_cond_latents.shape}")
             else:
                 logger.info("Only one image conditioning frame received, img2vid")
@@ -597,8 +611,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         if tora is not None:
             trajectory_length = tora["video_flow_features"].shape[1]
             logger.info(f"Tora trajectory length: {trajectory_length}")
-            if trajectory_length != latents.shape[1]:
-                raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
+            #if trajectory_length != latents.shape[1]:
+            #    raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
             for module in self.transformer.fuser_list:
                 for param in module.parameters():
                     param.data = param.data.to(device)
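
For reference, below is a minimal sketch (not part of the patch) of the frame-padding arithmetic the new pipeline code performs and that CogVideoDecode later undoes by slicing off the extra latent frames. It assumes the usual CogVideoX 1.5 values of vae_scale_factor_temporal = 4 and patch_size_t = 2; the helper name pad_to_patch_multiple is hypothetical.

```python
# Sketch of the padding bookkeeping added in pipeline_cogvideox.py.
# Assumed defaults: vae_scale_factor_temporal = 4, patch_size_t = 2 (CogVideoX 1.5).
# The function name is illustrative only; it does not exist in the repo.

def pad_to_patch_multiple(num_frames: int,
                          vae_scale_factor_temporal: int = 4,
                          patch_size_t: int = 2) -> tuple[int, int]:
    """Return (padded_num_frames, additional_latent_frames)."""
    # Latent frames produced by the temporal VAE compression.
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1

    # Pad the latent frame count up to a multiple of patch_size_t,
    # mirroring the logic the patch adds before preparing latents.
    additional_frames = 0
    if patch_size_t is not None and latent_frames % patch_size_t != 0:
        additional_frames = patch_size_t - latent_frames % patch_size_t
        num_frames += additional_frames * vae_scale_factor_temporal

    return num_frames, additional_frames


# Example: 49 input frames -> 13 latent frames (odd), so 1 extra latent frame
# (4 pixel frames) is added. The decoder then drops it again, as in nodes.py:
# vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
padded, extra = pad_to_patch_multiple(49)
assert (padded, extra) == (53, 1)
```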