From ac5daa714808da1819dbaabf487bb159e67243db Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 8 Oct 2024 22:47:12 +0300
Subject: [PATCH] fixes

---
 pipeline_cogvideox.py | 132 ++++++++++++++++++++++++------------------
 1 file changed, 76 insertions(+), 56 deletions(-)

diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index be97777..64208f0 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -587,6 +587,9 @@ class CogVideoXPipeline(VideoSysPipeline):
             print("Controlnet enabled with weights: ", control_weights)
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
+        else:
+            controlnet_states = None
+            control_weights = None
 
         # 10. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -702,19 +705,6 @@ class CogVideoXPipeline(VideoSysPipeline):
 
                     current_step_percentage = i / num_inference_steps
 
-                    context_queue = list(context(
-                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
-                    ))
-                    # controlnet frames are not temporally compressed, so try to match the context frames that are
-                    control_context_queue = list(context(
-                        i,
-                        num_inference_steps,
-                        control_frames.shape[1],
-                        context_frames * self.vae_scale_factor_temporal,
-                        context_stride * self.vae_scale_factor_temporal,
-                        context_overlap * self.vae_scale_factor_temporal,
-                    ))
-
                     # use same rotary embeddings for all context windows
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -722,40 +712,70 @@ class CogVideoXPipeline(VideoSysPipeline):
                         else None
                     )
 
-                    for c, control_c in zip(context_queue, control_context_queue):
-                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                        partial_control_frames = control_frames[:, control_c, :, :, :]
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
 
-                        controlnet_states = None
-
-                        if (control_start <= current_step_percentage <= control_end):
-                            # extract controlnet hidden state
-                            controlnet_states = self.controlnet(
+                    if controlnet is not None:
+                        # controlnet frames are not temporally compressed, so try to match the context frames that are
+                        control_context_queue = list(context(
+                            i,
+                            num_inference_steps,
+                            control_frames.shape[1],
+                            context_frames * self.vae_scale_factor_temporal,
+                            context_stride * self.vae_scale_factor_temporal,
+                            context_overlap * self.vae_scale_factor_temporal,
+                        ))
+
+                        for c, control_c in zip(context_queue, control_context_queue):
+                            partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                            partial_control_frames = control_frames[:, control_c, :, :, :]
+
+                            controlnet_states = None
+
+                            if (control_start <= current_step_percentage <= control_end):
+                                # extract controlnet hidden state
+                                controlnet_states = self.controlnet(
+                                    hidden_states=partial_latent_model_input,
+                                    encoder_hidden_states=prompt_embeds,
+                                    image_rotary_emb=image_rotary_emb,
+                                    controlnet_states=partial_control_frames,
+                                    timestep=timestep,
+                                    return_dict=False,
+                                )[0]
+                                if isinstance(controlnet_states, (tuple, list)):
+                                    controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
+                                else:
+                                    controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
+
+                            # predict noise model_output
+                            noise_pred[:, c, :, :, :] += self.transformer(
                                 hidden_states=partial_latent_model_input,
                                 encoder_hidden_states=prompt_embeds,
-                                image_rotary_emb=image_rotary_emb,
-                                controlnet_states=partial_control_frames,
                                 timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
                                 return_dict=False,
+                                controlnet_states=controlnet_states,
+                                controlnet_weights=control_weights,
                             )[0]
-                            if isinstance(controlnet_states, (tuple, list)):
-                                controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
-                            else:
-                                controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
-
-                        # predict noise model_output
-                        noise_pred[:, c, :, :, :] += self.transformer(
-                            hidden_states=partial_latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            image_rotary_emb=image_rotary_emb,
-                            return_dict=False,
-                            controlnet_states=controlnet_states,
-                            controlnet_weights=control_weights,
-                        )[0]
-                        counter[:, c, :, :, :] += 1
-                        noise_pred = noise_pred.float()
+                            counter[:, c, :, :, :] += 1
+                            noise_pred = noise_pred.float()
+                    else:
+                        for c in context_queue:
+                            partial_latent_model_input = latent_model_input[:, c, :, :, :]
+
+                            # predict noise model_output
+                            noise_pred[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False
+                            )[0]
+
+                            counter[:, c, :, :, :] += 1
+                            noise_pred = noise_pred.float()
 
                     noise_pred /= counter
                     if do_classifier_free_guidance:
@@ -794,23 +814,23 @@ class CogVideoXPipeline(VideoSysPipeline):
 
                     current_step_percentage = i / num_inference_steps
 
-                    controlnet_states = None
-                    if (control_start <= current_step_percentage <= control_end):
-                        # extract controlnet hidden state
-                        controlnet_states = self.controlnet(
-                            hidden_states=latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            image_rotary_emb=image_rotary_emb,
-                            controlnet_states=control_frames,
-                            timestep=timestep,
-                            return_dict=False,
-                        )[0]
-                        if isinstance(controlnet_states, (tuple, list)):
-                            controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
-                        else:
-                            controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+                    if controlnet is not None:
+                        controlnet_states = None
+                        if (control_start <= current_step_percentage <= control_end):
+                            # extract controlnet hidden state
+                            controlnet_states = self.controlnet(
+                                hidden_states=latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                image_rotary_emb=image_rotary_emb,
+                                controlnet_states=control_frames,
+                                timestep=timestep,
+                                return_dict=False,
+                            )[0]
+                            if isinstance(controlnet_states, (tuple, list)):
+                                controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                            else:
+                                controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
 
-                    # predict noise model_output
                     noise_pred = self.transformer(
                         hidden_states=latent_model_input,
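
Note on the window arithmetic in the patch: it relies on a `context(...)` scheduler defined elsewhere in the file, which is not shown here. The sketch below is a simplified stand-in, not the repository's actual helper: it ignores the `step`, `num_steps`, and `stride` arguments (a real scheduler may shift windows per step), and all names and numbers are illustrative. It demonstrates the one invariant the patch depends on: multiplying every window parameter by `vae_scale_factor_temporal` keeps the uncompressed control-frame windows aligned one-to-one with the temporally compressed latent windows.

    def context(step, num_steps, num_frames, window, stride, overlap):
        # Yield overlapping index windows over `num_frames` frames.
        # `step`, `num_steps`, and `stride` are accepted only for
        # signature compatibility with the pipeline's calls.
        start = 0
        while start < num_frames:
            end = min(start + window, num_frames)
            yield list(range(start, end))
            if end == num_frames:
                break
            start += window - overlap

    vae_scale_factor_temporal = 4  # CogVideoX VAE packs 4 video frames into 1 latent frame

    # Windows over 13 latent frames...
    latent_windows = list(context(0, 50, 13, 4, 1, 2))
    # ...and the matching windows over the 52 uncompressed control frames:
    # frame count, window size, stride, and overlap are all scaled by the
    # temporal factor, exactly as the patch does for control_context_queue.
    control_windows = list(context(0, 50, 13 * 4, 4 * 4, 1 * 4, 2 * 4))

    assert len(latent_windows) == len(control_windows)
    for lat, ctl in zip(latent_windows, control_windows):
        print(f"latents {lat[0]}..{lat[-1]}  <->  control {ctl[0]}..{ctl[-1]}")

Because both queues have the same length, the `zip(context_queue, control_context_queue)` in the patch never silently drops a window; this is also why the `else` branch (no controlnet) can iterate over `context_queue` alone.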