Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2026-01-23 19:54:24 +08:00)
fixes

commit ac5daa7148
parent 032a849bc6
@@ -587,6 +587,9 @@ class CogVideoXPipeline(VideoSysPipeline):
             print("Controlnet enabled with weights: ", control_weights)
+            control_start = controlnet["control_start"]
+            control_end = controlnet["control_end"]
         else:
             controlnet_states = None
             control_weights= None

         # 10. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
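Not part of the diff: a minimal sketch of how a control_start/control_end pair read from the controlnet dict gates the controlnet by denoising progress, mirroring the control_start <= current_step_percentage <= control_end check used later in this commit. The step count and the two thresholds below are hypothetical values, not taken from the repository.

# Illustrative only: gate a controlnet by denoising progress.
num_inference_steps = 50                 # hypothetical step count
control_start, control_end = 0.0, 0.6    # hypothetical thresholds from the controlnet dict

for i in range(num_inference_steps):
    current_step_percentage = i / num_inference_steps
    apply_controlnet = control_start <= current_step_percentage <= control_end
    if i % 10 == 0:
        print(f"step {i}: progress={current_step_percentage:.2f}, controlnet active={apply_controlnet}")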
@@ -702,19 +705,6 @@ class CogVideoXPipeline(VideoSysPipeline):

                     current_step_percentage = i / num_inference_steps

-                    context_queue = list(context(
-                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
-                    ))
-                    # controlnet frames are not temporally compressed, so try to match the context frames that are
-                    control_context_queue = list(context(
-                        i,
-                        num_inference_steps,
-                        control_frames.shape[1],
-                        context_frames * self.vae_scale_factor_temporal,
-                        context_stride * self.vae_scale_factor_temporal,
-                        context_overlap * self.vae_scale_factor_temporal,
-                    ))
-
                     # use same rotary embeddings for all context windows
                     image_rotary_emb = (
                         self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
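Not part of the diff: the block removed above reappears in the next hunk under an if controlnet is not None: guard. As a rough sketch of why the control windows are scaled by vae_scale_factor_temporal (the code comment notes that controlnet frames are not temporally compressed, while the latent context windows are), the toy window helper below is a stand-in for the repo's context(...) scheduler, not its actual implementation; the real scheduler also takes a stride and varies with the step index i.

# Illustrative stand-in for the repo's context(...) scheduler: fixed overlapping
# windows over frame indices.
def toy_windows(num_frames, window, overlap):
    step = max(window - overlap, 1)
    for start in range(0, max(num_frames - overlap, 1), step):
        yield list(range(start, min(start + window, num_frames)))

vae_scale_factor_temporal = 4  # hypothetical temporal compression factor of the VAE
latent_frames, context_frames, context_overlap = 13, 12, 4

latent_windows = list(toy_windows(latent_frames, context_frames, context_overlap))
control_windows = list(toy_windows(
    latent_frames * vae_scale_factor_temporal,    # control frames are not temporally compressed
    context_frames * vae_scale_factor_temporal,
    context_overlap * vae_scale_factor_temporal,
))
# Each latent context window pairs with a control window covering ~4x as many frames.
for lw, cw in zip(latent_windows, control_windows):
    print(len(lw), "latent frames  <->", len(cw), "control frames")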
@@ -722,40 +712,70 @@ class CogVideoXPipeline(VideoSysPipeline):
                         else None
                     )

-                    for c, control_c in zip(context_queue, control_context_queue):
-                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                        partial_control_frames = control_frames[:, control_c, :, :, :]
-
-                        controlnet_states = None
-
-                        if (control_start <= current_step_percentage <= control_end):
-                            # extract controlnet hidden state
-                            controlnet_states = self.controlnet(
-                                hidden_states=partial_latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                image_rotary_emb=image_rotary_emb,
-                                controlnet_states=partial_control_frames,
-                                timestep=timestep,
-                                return_dict=False,
-                            )[0]
-                            if isinstance(controlnet_states, (tuple, list)):
-                                controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
-                            else:
-                                controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
-
-                        # predict noise model_output
-                        noise_pred[:, c, :, :, :] += self.transformer(
-                            hidden_states=partial_latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            image_rotary_emb=image_rotary_emb,
-                            return_dict=False,
-                            controlnet_states=controlnet_states,
-                            controlnet_weights=control_weights,
-                        )[0]
-
-                        counter[:, c, :, :, :] += 1
-                        noise_pred = noise_pred.float()
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+
+                    if controlnet is not None:
+                        # controlnet frames are not temporally compressed, so try to match the context frames that are
+                        control_context_queue = list(context(
+                            i,
+                            num_inference_steps,
+                            control_frames.shape[1],
+                            context_frames * self.vae_scale_factor_temporal,
+                            context_stride * self.vae_scale_factor_temporal,
+                            context_overlap * self.vae_scale_factor_temporal,
+                        ))
+
+                        for c, control_c in zip(context_queue, control_context_queue):
+                            partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                            partial_control_frames = control_frames[:, control_c, :, :, :]
+
+                            controlnet_states = None
+
+                            if (control_start <= current_step_percentage <= control_end):
+                                # extract controlnet hidden state
+                                controlnet_states = self.controlnet(
+                                    hidden_states=partial_latent_model_input,
+                                    encoder_hidden_states=prompt_embeds,
+                                    image_rotary_emb=image_rotary_emb,
+                                    controlnet_states=partial_control_frames,
+                                    timestep=timestep,
+                                    return_dict=False,
+                                )[0]
+                                if isinstance(controlnet_states, (tuple, list)):
+                                    controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
+                                else:
+                                    controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
+
+                            # predict noise model_output
+                            noise_pred[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False,
+                                controlnet_states=controlnet_states,
+                                controlnet_weights=control_weights,
+                            )[0]
+
+                            counter[:, c, :, :, :] += 1
+                            noise_pred = noise_pred.float()
+                    else:
+                        for c in context_queue:
+                            partial_latent_model_input = latent_model_input[:, c, :, :, :]
+
+                            # predict noise model_output
+                            noise_pred[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False
+                            )[0]
+
+                            counter[:, c, :, :, :] += 1
+                            noise_pred = noise_pred.float()

                     noise_pred /= counter
                     if do_classifier_free_guidance:
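Not part of the diff: a minimal sketch of the accumulate-and-average pattern used in the hunk above, where per-window predictions are summed into noise_pred, a counter tensor records how many windows touched each frame, and noise_pred /= counter takes the mean over overlapping windows. The tensor shapes, window indices, and the random stand-in for the transformer output are hypothetical.

import torch

# Illustrative only: average overlapping per-window predictions, as done with
# noise_pred[:, c] += ... ; counter[:, c] += 1 ; noise_pred /= counter above.
batch, frames, channels, height, width = 1, 13, 4, 8, 8
noise_pred = torch.zeros(batch, frames, channels, height, width)
counter = torch.zeros(batch, frames, 1, 1, 1)

windows = [list(range(0, 12)), list(range(8, 13))]  # hypothetical overlapping context windows
for c in windows:
    partial_pred = torch.randn(batch, len(c), channels, height, width)  # stand-in for the transformer output
    noise_pred[:, c, :, :, :] += partial_pred
    counter[:, c, :, :, :] += 1

noise_pred /= counter     # frames covered by two windows get the mean of both predictions
print(counter.squeeze())  # frames 8..11 were covered twice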
@@ -794,23 +814,23 @@ class CogVideoXPipeline(VideoSysPipeline):

                     current_step_percentage = i / num_inference_steps

-                    controlnet_states = None
-                    if (control_start <= current_step_percentage <= control_end):
-                        # extract controlnet hidden state
-                        controlnet_states = self.controlnet(
-                            hidden_states=latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            image_rotary_emb=image_rotary_emb,
-                            controlnet_states=control_frames,
-                            timestep=timestep,
-                            return_dict=False,
-                        )[0]
-                        if isinstance(controlnet_states, (tuple, list)):
-                            controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
-                        else:
-                            controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+                    if controlnet is not None:
+                        controlnet_states = None
+                        if (control_start <= current_step_percentage <= control_end):
+                            # extract controlnet hidden state
+                            controlnet_states = self.controlnet(
+                                hidden_states=latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                image_rotary_emb=image_rotary_emb,
+                                controlnet_states=control_frames,
+                                timestep=timestep,
+                                return_dict=False,
+                            )[0]
+                            if isinstance(controlnet_states, (tuple, list)):
+                                controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                            else:
+                                controlnet_states = controlnet_states.to(dtype=self.vae.dtype)


                     # predict noise model_output
                     noise_pred = self.transformer(
                         hidden_states=latent_model_input,
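Not part of the diff: a guess at what the two wrapping changes fix. In the pre-change code, control_start and control_end are only assigned when a controlnet is supplied, so evaluating the gate without one would raise a NameError; the tiny repro below illustrates that failure mode with made-up values and a hypothetical helper, not the pipeline's actual call path.

# Illustrative repro (hypothetical values and function): the `if controlnet is not None:`
# guard added in this commit keeps controlnet-only names from being referenced
# when no controlnet is configured.
def denoise_step(controlnet, current_step_percentage):
    if controlnet is not None:
        control_start = controlnet["control_start"]
        control_end = controlnet["control_end"]

    if controlnet is not None:  # without this guard the next line raises NameError when controlnet is None
        if control_start <= current_step_percentage <= control_end:
            return "controlnet applied"
    return "plain transformer step"

print(denoise_step({"control_start": 0.0, "control_end": 0.6}, 0.3))  # controlnet applied
print(denoise_step(None, 0.3))                                        # plain transformer step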