From e047e6f07fab20608587ca354b9c404fbbcba8dd Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 8 Oct 2024 17:56:30 +0300
Subject: [PATCH] controlnet with context windowing

---
 cogvideox_fun/pipeline_cogvideox_control.py  | 13 -----
 cogvideox_fun/pipeline_cogvideox_inpaint.py  | 14 -----
 .../cogvideox_2b_controlnet_example_01.json  |  4 +-
 pipeline_cogvideox.py                        | 53 +++++++++++++------
 4 files changed, 39 insertions(+), 45 deletions(-)

diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 966e0ee..545e084 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -829,8 +829,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     ))
                     counter = torch.zeros_like(latent_model_input)
                     noise_pred = torch.zeros_like(latent_model_input)
-                    if do_classifier_free_guidance:
-                        noise_uncond = torch.zeros_like(latent_model_input)
 
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -851,17 +849,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                             return_dict=False,
                             control_latents=partial_control_latents,
                         )[0]
-
-                        # uncond
-                        if do_classifier_free_guidance:
-                            noise_uncond[:, c, :, :, :] += self.transformer(
-                                hidden_states=partial_latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                timestep=timestep,
-                                image_rotary_emb=image_rotary_emb,
-                                return_dict=False,
-                                control_latents=partial_control_latents,
-                            )[0]
                         counter[:, c, :, :, :] += 1
 
                     noise_pred = noise_pred.float()
diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index 4c9d505..459e845 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -984,9 +984,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
 
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                    # Calculate the current step percentage
-                    current_step_percentage = i / num_inference_steps
-
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                     timestep = t.expand(latent_model_input.shape[0])
 
@@ -995,8 +992,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                     ))
                     counter = torch.zeros_like(latent_model_input)
                     noise_pred = torch.zeros_like(latent_model_input)
-                    if do_classifier_free_guidance:
-                        noise_uncond = torch.zeros_like(latent_model_input)
 
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -1020,15 +1015,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                         )[0]
                         counter[:, c, :, :, :] += 1
 
-                        if do_classifier_free_guidance:
-                            noise_uncond[:, c, :, :, :] += self.transformer(
-                                hidden_states=partial_latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                timestep=timestep,
-                                image_rotary_emb=image_rotary_emb,
-                                return_dict=False,
-                                inpaint_latents=partial_inpaint_latents,
-                            )[0]
 
                     noise_pred = noise_pred.float()
diff --git a/examples/cogvideox_2b_controlnet_example_01.json b/examples/cogvideox_2b_controlnet_example_01.json
index cd5ccef..a73517d 100644
--- a/examples/cogvideox_2b_controlnet_example_01.json
+++ b/examples/cogvideox_2b_controlnet_example_01.json
@@ -737,7 +737,7 @@
       "widgets_values": {
         "frame_rate": 8,
         "loop_count": 0,
-        "filename_prefix": "CogVideoX5B",
+        "filename_prefix": "CogVideoX2B_controlnet",
         "format": "video/h264-mp4",
         "pix_fmt": "yuv420p",
         "crf": 19,
@@ -748,7 +748,7 @@
           "hidden": false,
           "paused": false,
           "params": {
-            "filename": "CogVideoX5B_00007.mp4",
+            "filename": "CogVideoX2B_00007.mp4",
             "subfolder": "",
             "type": "temp",
             "format": "video/h264-mp4",
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 10d3425..3ec28a2 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -678,8 +678,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                     counter = torch.zeros_like(latent_model_input)
                     noise_pred = torch.zeros_like(latent_model_input)
-                    if do_classifier_free_guidance:
-                        noise_uncond = torch.zeros_like(latent_model_input)
 
                     if image_cond_latents is not None:
                         latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
@@ -688,18 +686,49 @@ class CogVideoXPipeline(VideoSysPipeline):
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                     timestep = t.expand(latent_model_input.shape[0])
+                    current_step_percentage = i / num_inference_steps
+
                     context_queue = list(context(
                         i,
                         num_inference_steps,
                         latents.shape[1],
                         context_frames,
                         context_stride,
                         context_overlap,
-                    ))
+                    ))
+                    # controlnet frames are not temporally compressed, so try to match the context frames that are
+                    control_context_queue = list(context(
+                        i,
+                        num_inference_steps,
+                        control_frames.shape[1],
+                        context_frames * self.vae_scale_factor_temporal,
+                        context_stride * self.vae_scale_factor_temporal,
+                        context_overlap * self.vae_scale_factor_temporal,
+                    ))
+
+                    # use same rotary embeddings for all context windows
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
                         if self.transformer.config.use_rotary_positional_embeddings
                         else None
                     )
-                    for c in context_queue:
+                    for c, control_c in zip(context_queue, control_context_queue):
                         partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        partial_control_frames = control_frames[:, control_c, :, :, :]
+
+                        controlnet_states = None
+
+                        if (control_start <= current_step_percentage <= control_end):
+                            # extract controlnet hidden state
+                            controlnet_states = self.controlnet(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                image_rotary_emb=image_rotary_emb,
+                                controlnet_states=partial_control_frames,
+                                timestep=timestep,
+                                return_dict=False,
+                            )[0]
+                            if isinstance(controlnet_states, (tuple, list)):
+                                controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                            else:
+                                controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+
                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
@@ -707,18 +736,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                             hidden_states=partial_latent_model_input,
                             encoder_hidden_states=prompt_embeds,
                             timestep=timestep,
                             image_rotary_emb=image_rotary_emb,
                             return_dict=False,
+                            controlnet_states=controlnet_states,
+                            controlnet_weights=control_strength,
                         )[0]
-                        # uncond
-                        if do_classifier_free_guidance:
-                            noise_uncond[:, c, :, :, :] += self.transformer(
-                                hidden_states=partial_latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                timestep=timestep,
-                                image_rotary_emb=image_rotary_emb,
-                                return_dict=False,
-                            )[0]
-
                         counter[:, c, :, :, :] += 1
                     noise_pred = noise_pred.float()
@@ -757,10 +778,10 @@ class CogVideoXPipeline(VideoSysPipeline):
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                     timestep = t.expand(latent_model_input.shape[0])
-                    current_sampling_percent = i / len(timesteps)
+                    current_step_percentage = i / num_inference_steps
 
                     controlnet_states = None
-                    if (control_start < current_sampling_percent < control_end):
+                    if (control_start <= current_step_percentage <= control_end):
                         # extract controlnet hidden state
                         controlnet_states = self.controlnet(
                             hidden_states=latent_model_input,