diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index e46759d..877a9f7 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -267,7 +267,9 @@ class CogVideoXPipeline(DiffusionPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        start_frame: int = None,
+        end_frame: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -284,14 +286,20 @@ class CogVideoXPipeline(DiffusionPipeline):
             use_real=True,
         )
 
+        freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
+        freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
+
+        if start_frame is not None:
+            freqs_cos = freqs_cos[start_frame:end_frame]
+            freqs_sin = freqs_sin[start_frame:end_frame]
+
+        freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
+        freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
+
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
@@ -436,12 +444,12 @@ class CogVideoXPipeline(DiffusionPipeline):
 
         print("latents.device", latents.device)
 
-        # 6.5. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
+        # # 6.5. Create rotary embeds if required
+        # image_rotary_emb = (
+        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+        #     if self.transformer.config.use_rotary_positional_embeddings
+        #     else None
+        # )
 
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -477,6 +485,12 @@ class CogVideoXPipeline(DiffusionPipeline):
                     #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                     #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
                    latents_tile = latents[:, input_start_t:input_end_t, :, :, :]
                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
@@ -496,7 +510,7 @@ class CogVideoXPipeline(DiffusionPipeline):
 
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                    # compute the previous noisy sample x_t -> x_t-1
                    latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
@@ -527,6 +541,11 @@ class CogVideoXPipeline(DiffusionPipeline):
                            comfy_pbar.update(1)
                # ==========================================
                else:
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
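Note on the frame-slicing change (reviewer sketch, not part of the patch): the new start_frame/end_frame arguments view the flat (num_frames * grid_h * grid_w, dim) frequency tensors as (num_frames, grid_h * grid_w, dim), slice the frame axis, and flatten back, so each temporal tile receives rotary embeddings for its absolute frame positions rather than positions restarting at zero. The snippet below checks that logic in isolation; all sizes are made up, and a random tensor stands in for the output of get_3d_rotary_pos_embed, so only the shapes and the slicing behavior are meaningful.

    import torch

    # Hypothetical sizes; in the pipeline these come from height/width, the VAE
    # spatial scale factor, the transformer patch size, and the latent frame count.
    num_frames, grid_h, grid_w, dim = 13, 30, 45, 64
    start_frame, end_frame = 4, 9

    # Stand-in for one of the tensors returned by get_3d_rotary_pos_embed with
    # use_real=True: one row per (frame, spatial position) token.
    freqs = torch.randn(num_frames * grid_h * grid_w, dim)

    # The patched logic: expose the frame axis, slice the tile's frames, flatten back.
    tile = freqs.view(num_frames, grid_h * grid_w, -1)[start_frame:end_frame]
    tile = tile.view(-1, tile.shape[-1])

    # Equivalent direct slice on the flat token axis.
    expected = freqs[start_frame * grid_h * grid_w : end_frame * grid_h * grid_w]
    assert torch.equal(tile, expected)
    print(tile.shape)  # torch.Size([6750, 64]), i.e. ((end_frame - start_frame) * grid_h * grid_w, dim)

One detail worth noting: slicing only the leading (frame) dimension leaves the tensor contiguous, so the second .view(-1, ...) in the patch is safe and allocates no copy.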