diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 0db9671..7e43413 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -473,7 +473,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         if image_cond_latents is not None:
             image_cond_frame_count = image_cond_latents.size(1)
             patch_size_t = self.transformer.config.patch_size_t
-            if image_cond_latents.shape[1] == 2:
+            if image_cond_frame_count == 2:
                 logger.info("More than one image conditioning frame received, interpolating")
                 padding_shape = (
                     batch_size,
@@ -485,11 +485,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
                 image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
                 if patch_size_t:
-                    first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
+                    first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
                     image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
                 logger.info(f"image cond latents shape: {image_cond_latents.shape}")
-            elif image_cond_latents.shape[1] == 1:
+            elif image_cond_frame_count == 1:
                 logger.info("Only one image conditioning frame received, img2vid")
                 if self.input_with_padding:
                     padding_shape = (
@@ -503,7 +503,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
                     # Select the first frame along the second dimension
                     if patch_size_t:
-                        first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
+                        first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
                         image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
                 else:
                     image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
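For context, what this diff fixes: `image_cond_frame_count` is captured before the padding frames are concatenated, so reusing the stale count for the `% patch_size_t` slice pads the wrong number of leading frames and can leave a temporal length that is not divisible by `patch_size_t`. Below is a minimal standalone sketch of the pitfall; variable names mirror the diff, but the patch size and all tensor shapes are illustrative assumptions, not values taken from the pipeline.

```python
import torch

batch_size, channels, height, width = 1, 16, 60, 90
patch_size_t = 2  # hypothetical temporal patch size

# Two conditioning frames, as in the interpolation branch.
image_cond_latents = torch.randn(batch_size, 2, channels, height, width)
image_cond_frame_count = image_cond_latents.size(1)  # 2, captured early

# Pad between the first and last frame (11 zero frames -> 13 frames total).
latent_padding = torch.zeros(batch_size, 11, channels, height, width)
image_cond_latents = torch.cat(
    [image_cond_latents[:, :1], latent_padding, image_cond_latents[:, -1:]],
    dim=1,
)

# Stale count: 2 % 2 == 0 -> zero frames prepended, 13 frames remain,
# which is not divisible by patch_size_t.
stale = image_cond_frame_count % patch_size_t       # 0
# Fresh count: 13 % 2 == 1 -> one frame prepended, 14 frames total.
fresh = image_cond_latents.size(1) % patch_size_t   # 1
assert stale == 0 and fresh == 1

first_frame = image_cond_latents[:, :fresh, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
assert image_cond_latents.size(1) % patch_size_t == 0  # 14 % 2 == 0
```

The `shape[1] == 2` / `shape[1] == 1` renames to `image_cond_frame_count` are behavior-preserving, since the count is read before any concatenation; only the two `% patch_size_t` sites needed the fresh `size(1)`.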