From dccc8bdcb723bb9b9ab632a1abbabb7da3497ad0 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 29 Oct 2024 19:51:36 +0200 Subject: [PATCH] Tora for fun context scheduling --- cogvideox_fun/pipeline_cogvideox_inpaint.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py index ba95388..f2a8680 100644 --- a/cogvideox_fun/pipeline_cogvideox_inpaint.py +++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py @@ -988,7 +988,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): pbar.update(1) # ========================================== elif use_context_schedule: - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1001,6 +1000,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): counter = torch.zeros_like(latent_model_input) noise_pred = torch.zeros_like(latent_model_input) + current_step_percentage = i / num_inference_steps + image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, context_frames, device) if self.transformer.config.use_rotary_positional_embeddings @@ -1011,6 +1012,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): partial_latent_model_input = latent_model_input[:, c, :, :, :] partial_inpaint_latents = inpaint_latents[:, c, :, :, :] partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :] + if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]): + if do_classifier_free_guidance: + partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous() + else: + partial_video_flow_features = tora["video_flow_features"][:, c, :, :, :] + else: + partial_video_flow_features = None # predict noise model_output noise_pred[:, c, :, :, :] += self.transformer( @@ -1020,6 +1028,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): image_rotary_emb=image_rotary_emb, return_dict=False, inpaint_latents=partial_inpaint_latents, + video_flow_features=partial_video_flow_features )[0] counter[:, c, :, :, :] += 1