mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

controlnet with context windowing
commit e047e6f07f (parent d76229c49b)
@@ -829,8 +829,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 ))
                 counter = torch.zeros_like(latent_model_input)
                 noise_pred = torch.zeros_like(latent_model_input)
-                if do_classifier_free_guidance:
-                    noise_uncond = torch.zeros_like(latent_model_input)

                 image_rotary_emb = (
                     self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -851,17 +849,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                         return_dict=False,
                         control_latents=partial_control_latents,
                     )[0]

-                    # uncond
-                    if do_classifier_free_guidance:
-                        noise_uncond[:, c, :, :, :] += self.transformer(
-                            hidden_states=partial_latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            image_rotary_emb=image_rotary_emb,
-                            return_dict=False,
-                            control_latents=partial_control_latents,
-                        )[0]
-
                     counter[:, c, :, :, :] += 1
                 noise_pred = noise_pred.float()
@@ -984,9 +984,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                # Calculate the current step percentage
-                current_step_percentage = i / num_inference_steps
-
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

@@ -995,8 +992,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                 ))
                 counter = torch.zeros_like(latent_model_input)
                 noise_pred = torch.zeros_like(latent_model_input)
-                if do_classifier_free_guidance:
-                    noise_uncond = torch.zeros_like(latent_model_input)

                 image_rotary_emb = (
                     self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
@@ -1020,15 +1015,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                     )[0]

                     counter[:, c, :, :, :] += 1
-                    if do_classifier_free_guidance:
-                        noise_uncond[:, c, :, :, :] += self.transformer(
-                            hidden_states=partial_latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            image_rotary_emb=image_rotary_emb,
-                            return_dict=False,
-                            inpaint_latents=partial_inpaint_latents,
-                        )[0]

                 noise_pred = noise_pred.float()

@@ -737,7 +737,7 @@
       "widgets_values": {
         "frame_rate": 8,
         "loop_count": 0,
-        "filename_prefix": "CogVideoX5B",
+        "filename_prefix": "CogVideoX2B_controlnet",
         "format": "video/h264-mp4",
         "pix_fmt": "yuv420p",
         "crf": 19,
@@ -748,7 +748,7 @@
         "hidden": false,
         "paused": false,
         "params": {
-          "filename": "CogVideoX5B_00007.mp4",
+          "filename": "CogVideoX2B_00007.mp4",
           "subfolder": "",
           "type": "temp",
           "format": "video/h264-mp4",
@@ -678,8 +678,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 counter = torch.zeros_like(latent_model_input)
                 noise_pred = torch.zeros_like(latent_model_input)
-                if do_classifier_free_guidance:
-                    noise_uncond = torch.zeros_like(latent_model_input)

                 if image_cond_latents is not None:
                     latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
@@ -688,18 +686,49 @@ class CogVideoXPipeline(VideoSysPipeline):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

+                current_step_percentage = i / num_inference_steps
+
                 context_queue = list(context(
                     i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
                 ))
+                # controlnet frames are not temporally compressed, so try to match the context frames that are
+                control_context_queue = list(context(
+                    i,
+                    num_inference_steps,
+                    control_frames.shape[1],
+                    context_frames * self.vae_scale_factor_temporal,
+                    context_stride * self.vae_scale_factor_temporal,
+                    context_overlap * self.vae_scale_factor_temporal,
+                ))
+
+                # use same rotary embeddings for all context windows
                 image_rotary_emb = (
                     self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
                     if self.transformer.config.use_rotary_positional_embeddings
                     else None
                 )

-                for c in context_queue:
+                for c, control_c in zip(context_queue, control_context_queue):
                     partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                    partial_control_frames = control_frames[:, control_c, :, :, :]
+
+                    controlnet_states = None
+
+                    if (control_start <= current_step_percentage <= control_end):
+                        # extract controlnet hidden state
+                        controlnet_states = self.controlnet(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            image_rotary_emb=image_rotary_emb,
+                            controlnet_states=partial_control_frames,
+                            timestep=timestep,
+                            return_dict=False,
+                        )[0]
+                        if isinstance(controlnet_states, (tuple, list)):
+                            controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                        else:
+                            controlnet_states = controlnet_states.to(dtype=self.vae.dtype)

                     # predict noise model_output
                     noise_pred[:, c, :, :, :] += self.transformer(
                         hidden_states=partial_latent_model_input,
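Note on the windowing math above: the latents are temporally compressed by the VAE while the controlnet consumes uncompressed pixel-space frames, so every control window parameter is multiplied by self.vae_scale_factor_temporal and the two queues are consumed in lockstep with zip(). The sketch below is purely illustrative and is not the wrapper's context() scheduler: simple_windows, the frame counts, and the compression factor of 4 are assumptions, and the real scheduler's stride/step arguments are ignored.

# Illustrative only: a simplified uniform window scheduler standing in for context().
def simple_windows(num_frames, window_size, overlap):
    step = max(window_size - overlap, 1)
    windows = []
    for start in range(0, num_frames, step):
        window = [min(start + k, num_frames - 1) for k in range(window_size)]
        windows.append(window)
        if window[-1] == num_frames - 1:
            break
    return windows

vae_scale_factor_temporal = 4                 # assumed CogVideoX temporal compression factor
latent_frames, control_pixel_frames = 13, 49  # assumed toy frame counts
context_frames, context_overlap = 8, 2

latent_queue = simple_windows(latent_frames, context_frames, context_overlap)
control_queue = simple_windows(
    control_pixel_frames,
    context_frames * vae_scale_factor_temporal,
    context_overlap * vae_scale_factor_temporal,
)

# Pairing mirrors `for c, control_c in zip(context_queue, control_context_queue)`:
# each latent window lines up with the control window covering roughly the same time span.
for c, control_c in zip(latent_queue, control_queue):
    print("latent", c[0], "-", c[-1], "<->", "control", control_c[0], "-", control_c[-1])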
@@ -707,18 +736,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                         timestep=timestep,
                         image_rotary_emb=image_rotary_emb,
                         return_dict=False,
+                        controlnet_states=controlnet_states,
+                        controlnet_weights=control_strength,
                     )[0]

-                    # uncond
-                    if do_classifier_free_guidance:
-                        noise_uncond[:, c, :, :, :] += self.transformer(
-                            hidden_states=partial_latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            image_rotary_emb=image_rotary_emb,
-                            return_dict=False,
-                        )[0]
-
                     counter[:, c, :, :, :] += 1
                 noise_pred = noise_pred.float()

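Note on the removed unconditional pass: in the hunks above the batch is already doubled for classifier-free guidance (torch.cat([latents] * 2)), so each windowed transformer call produces both the unconditional and the conditional prediction and the separate noise_uncond buffer is redundant. The toy sketch below shows the accumulate-and-average pattern over overlapping windows followed by the usual guidance split; the shapes, window indices, batch ordering, and the random stand-in for the transformer call are assumptions rather than pipeline code.

import torch

# Toy sketch (assumed shapes and values, not the pipeline's code): accumulate per-window
# predictions, average the overlaps, then split the doubled batch for guidance.
guidance_scale = 6.0
B, T, C, H, W = 2, 13, 16, 60, 90                  # doubled batch, assumed [uncond, cond] order
latent_model_input = torch.randn(B, T, C, H, W)

noise_pred = torch.zeros_like(latent_model_input)
counter = torch.zeros_like(latent_model_input)

windows = [list(range(0, 8)), list(range(5, 13))]  # toy overlapping context windows
for c in windows:
    # stand-in for the windowed transformer call
    partial_pred = torch.randn_like(latent_model_input[:, c, :, :, :])
    noise_pred[:, c, :, :, :] += partial_pred
    counter[:, c, :, :, :] += 1

noise_pred = noise_pred.float() / counter          # average overlapping windows
noise_uncond, noise_text = noise_pred.chunk(2)     # split the doubled batch
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)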
@@ -757,10 +778,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                current_sampling_percent = i / len(timesteps)
+                current_step_percentage = i / num_inference_steps

                 controlnet_states = None
-                if (control_start < current_sampling_percent < control_end):
+                if (control_start <= current_step_percentage <= control_end):
                     # extract controlnet hidden state
                     controlnet_states = self.controlnet(
                         hidden_states=latent_model_input,
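Note on the gating change: the percentage is now i / num_inference_steps and the bounds are inclusive, so with control_start = 0.0 the controlnet already fires at step 0, which the previous strict comparison excluded. A small illustration with assumed values:

# Assumed values for illustration.
num_inference_steps = 50
control_start, control_end = 0.0, 0.6

for i in range(num_inference_steps):
    current_step_percentage = i / num_inference_steps
    apply_controlnet = control_start <= current_step_percentage <= control_end
    if i in (0, 29, 30, 31):
        print(i, round(current_step_percentage, 2), apply_controlnet)
# step 0 -> 0.0 True, step 30 -> 0.6 True, step 31 -> 0.62 False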