diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 1ec87e1..93983a8 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -742,6 +742,12 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         latents_all_list = []
         # =====================================================
 
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
         for t_i in range(grid_ts):
             if t_i < grid_ts - 1:
                 ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
@@ -751,12 +757,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             input_start_t = ofs_t
             input_end_t = ofs_t + t_tile_length
 
-            image_rotary_emb = (
-                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
-                if self.transformer.config.use_rotary_positional_embeddings
-                else None
-            )
-
             latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
             control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
@@ -832,15 +832,21 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     if do_classifier_free_guidance:
                         noise_uncond = torch.zeros_like(latent_model_input)
 
-                    for c in context_queue:
-                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                        partial_control_latents = current_control_latents[:, c, :, :, :]
-                        image_rotary_emb = (
-                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
                             if self.transformer.config.use_rotary_positional_embeddings
                             else None
                         )
 
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        partial_control_latents = current_control_latents[:, c, :, :, :]
+                        # image_rotary_emb = (
+                        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                        #     if self.transformer.config.use_rotary_positional_embeddings
+                        #     else None
+                        # )
+
                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
                             hidden_states=partial_latent_model_input,
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index 440bd3e..e9aff57 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -903,6 +903,12 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         latents_all_list = []
         # =====================================================
 
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
         for t_i in range(grid_ts):
             if t_i < grid_ts - 1:
                 ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
@@ -912,12 +918,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             input_start_t = ofs_t
             input_end_t = ofs_t + t_tile_length
 
-            image_rotary_emb = (
-                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
-                if self.transformer.config.use_rotary_positional_embeddings
-                else None
-            )
-
             latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
             inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
@@ -989,17 +989,17 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                     if do_classifier_free_guidance:
                         noise_uncond = torch.zeros_like(latent_model_input)
 
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
                     for c in context_queue:
                         partial_latent_model_input = latent_model_input[:, c, :, :, :]
                         partial_inpaint_latents = inpaint_latents[:, c, :, :, :]
                         partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]
 
-                        image_rotary_emb = (
-                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
-                            if self.transformer.config.use_rotary_positional_embeddings
-                            else None
-                        )
-
                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
                             hidden_states=partial_latent_model_input,
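
Note: both pipelines previously rebuilt the rotary positional embeddings inside every
temporal-tile and context-window iteration; this patch hoists the call out of the loops,
replacing the per-window arguments (context_frames=c, input_start_t/input_end_t) with a
fixed frame count so the call no longer depends on the loop variable. Below is a minimal
self-contained sketch of the same loop-invariant hoisting pattern; prepare_rope and its
token-count math are illustrative stand-ins, not the pipelines' actual
_prepare_rotary_positional_embeddings.

    import torch

    def prepare_rope(height: int, width: int, num_frames: int, device: torch.device):
        # Illustrative stub: builds (cos, sin) tables sized to the latent grid.
        # The real method computes 3D RoPE frequencies; this formula is made up.
        num_tokens = (height // 16) * (width // 16) * num_frames
        pos = torch.arange(num_tokens, device=device, dtype=torch.float32)
        return pos.cos(), pos.sin()

    height, width, context_frames = 480, 720, 13
    device = torch.device("cpu")
    context_queue = [list(range(i, i + context_frames)) for i in range(0, 26, 13)]

    # Hoisted: computed once, since nothing in the call varies per window.
    image_rotary_emb = prepare_rope(height, width, context_frames, device)

    for c in context_queue:
        # Before the patch, prepare_rope(...) was re-run here for every window;
        # now each window reuses the same precomputed tables.
        cos, sin = image_rotary_emb
        print(len(c), cos.shape, sin.shape)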