diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 2bae2f8..06d0efc 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -419,7 +419,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents.
-        latent_channels = self.transformer.config.in_channels
+        latent_channels = self.vae.config.latent_channels
 
         if latents is None and num_frames == t_tile_length:
             num_frames += 1
@@ -443,20 +443,24 @@ class CogVideoXPipeline(DiffusionPipeline):
             latents
         )
         latents = latents.to(self.transformer.dtype)
+        print("latents", latents.shape)
 
         # 5.5.
         if image_cond_latents is not None:
-            image_cond_latents = torch.cat(image_cond_latents, dim=0).to(self.transformer.dtype)  # .permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+            print("image_cond_latents", image_cond_latents.shape)
+            # image_cond_latents = torch.cat(image_cond_latents, dim=0).to(self.transformer.dtype)  # .permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
             padding_shape = (
                 batch_size,
-                num_frames - 1,
-                latent_channels,
+                latents.shape[1] - 1,
+                self.vae.config.latent_channels,
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial,
             )
+            print("padding_shape", padding_shape)
             latent_padding = torch.zeros(padding_shape, device=device, dtype=self.transformer.dtype)
-            image_latents = torch.cat([image_latents, latent_padding], dim=1)
+            image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
+            print("image_cond_latents", image_cond_latents.shape)
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -598,7 +602,11 @@ class CogVideoXPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 if image_cond_latents is not None:
+
+                    latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
+                    print("latent_model_input", latent_model_input.shape)
+                    print("image_cond_latents", image_cond_latents.shape)
                     latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
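For reference, here is a minimal, standalone sketch of the shape arithmetic this diff sets up: the single encoded conditioning frame is zero-padded to `latents.shape[1]` latent frames, then concatenated with the (CFG-doubled) noise latents along the channel dimension. All tensor sizes below are hypothetical stand-ins, not values taken from the pipeline:

```python
import torch

# Hypothetical stand-ins for the pipeline's runtime values.
B, F, C, H, W = 1, 13, 16, 60, 90                 # batch, latent frames, VAE latent channels, latent H/W
latents = torch.randn(B, F, C, H, W)              # noise latents, [B, F, C, H, W]
image_cond_latents = torch.randn(B, 1, C, H, W)   # encoded first frame, [B, 1, C, H, W]

# Zero-pad the single conditioning frame so it spans the same number of latent
# frames as `latents` (latents.shape[1] - 1 zero frames appended after the image).
padding_shape = (B, latents.shape[1] - 1, C, H, W)
latent_padding = torch.zeros(padding_shape)
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
assert image_cond_latents.shape == latents.shape

# Classifier-free guidance doubles the batch, so the conditioning latents are
# doubled too, then joined to the noise latents along the channel dim (dim=2).
do_classifier_free_guidance = True
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
print(latent_model_input.shape)  # torch.Size([2, 13, 32, 60, 90]) -> 2*C channels
```

The channel concatenation means the transformer sees twice the VAE's latent channels at its input, which is presumably why the diff switches `latent_channels` from `transformer.config.in_channels` to `vae.config.latent_channels`: the former would over-count when sizing the padding tensor.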