kijai 2024-10-08 22:47:12 +03:00
parent 032a849bc6
commit ac5daa7148

@@ -587,6 +587,9 @@ class CogVideoXPipeline(VideoSysPipeline):
             print("Controlnet enabled with weights: ", control_weights)
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
+        else:
+            controlnet_states = None
+            control_weights = None

         # 10. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
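The added else branch guarantees that controlnet_states and control_weights are bound even when no
controlnet is supplied, so code later in the pipeline can pass them through as plain None values.
A minimal sketch of the pattern, with a hypothetical helper name and an assumed dict layout:

    def resolve_controlnet_inputs(controlnet):
        """Unpack controlnet settings, falling back to None when disabled."""
        if controlnet is not None:
            # Keys mirror those read in the diff; "control_weights" is assumed.
            control_weights = controlnet["control_weights"]
            control_start = controlnet["control_start"]
            control_end = controlnet["control_end"]
        else:
            # Keep downstream keyword arguments valid without a controlnet.
            control_weights = control_start = control_end = None
        return control_weights, control_start, control_end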
@@ -702,9 +705,18 @@ class CogVideoXPipeline(VideoSysPipeline):
                     current_step_percentage = i / num_inference_steps

+                    # use same rotary embeddings for all context windows
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
                     context_queue = list(context(
                         i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
                     ))
+                    if controlnet is not None:
                         # controlnet frames are not temporally compressed, so try to match the context frames that are
                         control_context_queue = list(context(
                             i,
@@ -715,13 +727,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                             context_overlap * self.vae_scale_factor_temporal,
                         ))

-                    # use same rotary embeddings for all context windows
-                    image_rotary_emb = (
-                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
-                        if self.transformer.config.use_rotary_positional_embeddings
-                        else None
-                    )
-
                         for c, control_c in zip(context_queue, control_context_queue):
                             partial_latent_model_input = latent_model_input[:, c, :, :, :]
                             partial_control_frames = control_frames[:, control_c, :, :, :]
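The rotary embedding block is hoisted above the controlnet branch because every context window has
the same context_frames length, so one embedding set serves them all. Meanwhile, the VAE compresses
time, so a latent window of context_frames spans context_frames * self.vae_scale_factor_temporal
raw control frames; that is why every argument of the control-frame scheduler is scaled by that
factor. A toy sketch of the pairing follows; the real `context` scheduler is more involved, and the
temporal factor of 4 is an assumption:

    def toy_windows(num_frames, window, overlap):
        # Slide a fixed-size window across the frame axis with the given overlap.
        step = window - overlap
        starts = range(0, max(num_frames - overlap, 1), step)
        return [list(range(s, min(s + window, num_frames))) for s in starts]

    vae_scale_factor_temporal = 4  # assumed CogVideoX temporal compression
    latent_windows = toy_windows(num_frames=13, window=8, overlap=2)
    control_windows = toy_windows(
        num_frames=13 * vae_scale_factor_temporal,
        window=8 * vae_scale_factor_temporal,
        overlap=2 * vae_scale_factor_temporal,
    )
    # Each latent window pairs with a control window 4x as long, mirroring
    # zip(context_queue, control_context_queue) in the diff.
    for c, control_c in zip(latent_windows, control_windows):
        assert len(control_c) == len(c) * vae_scale_factor_temporal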
@@ -754,6 +759,21 @@ class CogVideoXPipeline(VideoSysPipeline):
                                 controlnet_weights=control_weights,
                             )[0]

+                            counter[:, c, :, :, :] += 1
+                            noise_pred = noise_pred.float()
+                    else:
+                        for c in context_queue:
+                            partial_latent_model_input = latent_model_input[:, c, :, :, :]
+
+                            # predict noise model_output
+                            noise_pred[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False
+                            )[0]
                             counter[:, c, :, :, :] += 1
                             noise_pred = noise_pred.float()
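Both branches accumulate each window's prediction into noise_pred and bump counter over the same
frame indices; overlapping windows are then blended by dividing the sum by the count, a division
that is not visible in this excerpt. A self-contained sketch of the pattern using torch, with toy
tensor sizes:

    import torch

    latents = torch.randn(1, 6, 16, 60, 90)    # (B, F, C, H, W), toy sizes
    noise_pred = torch.zeros_like(latents)
    counter = torch.zeros(1, 6, 1, 1, 1)

    for c in [[0, 1, 2, 3], [2, 3, 4, 5]]:     # toy overlapping windows
        # Stand-in for the per-window transformer call in the diff.
        window_pred = torch.randn_like(latents[:, c])
        noise_pred[:, c] += window_pred
        counter[:, c] += 1

    # Frames covered by two windows are averaged; others pass through unchanged.
    noise_pred = noise_pred / counter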
@@ -794,6 +814,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                 current_step_percentage = i / num_inference_steps

+                if controlnet is not None:
                     controlnet_states = None
                     if (control_start <= current_step_percentage <= control_end):
                         # extract controlnet hidden state
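The new outer guard wraps the whole controlnet block, while the inner check gates it per step: the
controlnet only runs while denoising progress lies inside [control_start, control_end]. A sketch of
the gating with illustrative values and a stub in place of the controlnet forward pass:

    def run_controlnet_stub():
        # Placeholder for the controlnet forward pass in the pipeline.
        return "controlnet hidden states"

    num_inference_steps = 50
    control_start, control_end = 0.0, 0.6  # illustrative schedule bounds

    for i in range(num_inference_steps):
        current_step_percentage = i / num_inference_steps
        if control_start <= current_step_percentage <= control_end:
            # Early steps: compute controlnet residuals for this timestep.
            controlnet_states = run_controlnet_stub()
        else:
            # Late steps: the transformer denoises without control guidance.
            controlnet_states = None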
@@ -810,7 +831,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                     else:
                         controlnet_states = controlnet_states.to(dtype=self.vae.dtype)

                 # predict noise model_output
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
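The excerpt ends inside this final transformer call. Judging from the windowed call earlier in the
diff, it presumably forwards the (possibly None) controlnet tensors alongside the usual inputs; a
hedged sketch of that shape, with any kwargs beyond those visible above unconfirmed:

    noise_pred = self.transformer(
        hidden_states=latent_model_input,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        image_rotary_emb=image_rotary_emb,
        controlnet_states=controlnet_states,  # None when disabled or outside the control window
        controlnet_weights=control_weights,   # None when no controlnet is supplied
        return_dict=False,
    )[0]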