mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00

temporal tiling for the control pipe

This commit is contained in:
parent bb6ea6b3a4
commit f9f06d595e
@@ -300,6 +300,16 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def _gaussian_weights(self, t_tile_length, t_batch_size):
        from numpy import pi, exp, sqrt

        var = 0.01
        midpoint = (t_tile_length - 1) / 2  # -1 because index goes from 0 to latent_width - 1
        t_probs = [exp(-(t-midpoint)*(t-midpoint)/(t_tile_length*t_tile_length)/(2*var)) / sqrt(2*pi*var) for t in range(t_tile_length)]
        weights = torch.tensor(t_probs)
        weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size, 1, 1, 1)
        return weights

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
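The new `_gaussian_weights` helper builds a bell-shaped weight curve over the frames of a temporal tile, so that when overlapping tiles are merged the frames near a tile's centre dominate and tile seams are cross-faded rather than hard-cut. Below is a minimal standalone sketch of that blending idea; the `gaussian_weights` helper, the 1-D tensors, and the 12/4/20 frame counts are illustrative stand-ins, not the pipeline's actual shapes:

```python
import torch
from numpy import pi, exp, sqrt

def gaussian_weights(t_tile_length: int) -> torch.Tensor:
    # Same formula as _gaussian_weights above: a Gaussian centred on the tile midpoint.
    var = 0.01
    midpoint = (t_tile_length - 1) / 2
    probs = [exp(-(t - midpoint) ** 2 / (t_tile_length ** 2) / (2 * var)) / sqrt(2 * pi * var)
             for t in range(t_tile_length)]
    return torch.tensor(probs).float()

# Illustrative blend: two 12-frame tiles with a 4-frame overlap covering 20 frames total.
tile_len, total = 12, 20
weights = gaussian_weights(tile_len)

tiles = [torch.randn(tile_len), torch.randn(tile_len)]  # stand-ins for two denoised latent tiles
offsets = [0, total - tile_len]                         # the last tile is anchored to the end

accum = torch.zeros(total)
contributors = torch.zeros(total)
for tile, ofs in zip(tiles, offsets):
    accum[ofs:ofs + tile_len] += tile * weights
    contributors[ofs:ofs + tile_len] += weights

blended = accum / contributors  # weighted average in the 4-frame overlap, identity elsewhere
print(blended.shape)            # torch.Size([20])
```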
@@ -372,7 +382,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        start_frame: int = None,
        end_frame: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -388,6 +400,16 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
            temporal_size=num_frames,
            use_real=True,
        )

        if start_frame is not None:
            freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
            freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)

            freqs_cos = freqs_cos[start_frame:end_frame]
            freqs_sin = freqs_sin[start_frame:end_frame]

            freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
            freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
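The new `start_frame`/`end_frame` arguments let the rotary positional embeddings be computed once for the full clip and then cropped to the frame range of the current temporal tile, so each tile keeps its absolute temporal position. A rough shape walk-through as a standalone sketch; the tensor sizes and the `freqs_cos` stand-in are illustrative, not taken from the model config:

```python
import torch

num_frames, grid_h, grid_w, dim = 20, 6, 10, 64             # illustrative sizes only
freqs_cos = torch.randn(num_frames * grid_h * grid_w, dim)   # stand-in for the full-clip embedding

start_frame, end_frame = 8, 20                               # a tile anchored at latent frame 8

# reshape to (frames, tokens_per_frame, dim), crop the tile's frame range, flatten back
freqs_tile = (
    freqs_cos.view(num_frames, grid_h * grid_w, -1)[start_frame:end_frame]
    .reshape(-1, dim)
)
print(freqs_tile.shape)  # torch.Size([720, 64]): 12 frames x 60 tokens per frame
```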
@@ -447,6 +469,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        control_strength: float = 1.0,
        control_start_percent: float = 0.0,
        control_end_percent: float = 1.0,
        t_tile_length: int = 12,
        t_tile_overlap: int = 4,
        scheduler_name: str = "DPM",
    ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
@@ -524,10 +549,10 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if num_frames > 49:
            raise ValueError(
                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )
        # if num_frames > 49:
        #     raise ValueError(
        #         "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
        #     )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

@@ -590,26 +615,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
            if comfyui_progressbar:
                pbar.update(1)

        # if control_video is not None:
        #     video_length = control_video.shape[2]
        #     control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
        #     control_video = control_video.to(dtype=torch.float32)
        #     control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
        # else:
        #     control_video = None

        # control_video_latents = self.prepare_control_latents(
        #     None,
        #     control_video,
        #     batch_size,
        #     height,
        #     width,
        #     self.vae.dtype,
        #     device,
        #     generator,
        #     do_classifier_free_guidance
        # )[1]


        control_video_latents_input = (
            torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video
@@ -624,16 +629,30 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )


        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.5. Temporal tiling prep
        if "tiled" in scheduler_name:
            t_tile_length = t_tile_length // 4
            t_tile_overlap = t_tile_overlap // 4
            t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
            temporal_tiling = True

            print("Temporal tiling enabled")
        else:
            temporal_tiling = False
            print("Temporal tiling disabled")
        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
        #print("latents.shape", latents.shape)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
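When a "tiled" scheduler is selected, the user-facing tile length and overlap are floor-divided by 4 before tiling and the Gaussian weights are built at that reduced length. A small standalone sketch of that conversion using the node defaults; reading the division as the 4x temporal VAE compression (video frames to latent frames) is an assumption based on the defaults, not something the diff states explicitly:

```python
import torch

# Node defaults from nodes.py: 48-frame tiles with 8 frames of overlap (video frames).
t_tile_length, t_tile_overlap = 48, 8

# The pipeline floor-divides both by 4 before tiling (assumed here to map video frames
# to latent frames via the 4x temporal VAE compression).
t_tile_length //= 4    # 12 latent frames per tile
t_tile_overlap //= 4   # 2 latent frames of overlap

# _gaussian_weights returns a [1, T, 1, 1, 1] tensor so it broadcasts over latents
# shaped [batch, frames, channels, height, width]; the 1/sqrt(2*pi*var) constant is
# dropped in this sketch because it cancels when dividing by the contributor sum.
var = 0.01
midpoint = (t_tile_length - 1) / 2
t_probs = torch.exp(-(torch.arange(t_tile_length) - midpoint) ** 2
                    / (t_tile_length ** 2) / (2 * var))
weights = t_probs[None, :, None, None, None]
print(weights.shape)   # torch.Size([1, 12, 1, 1, 1])
```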
@@ -641,69 +660,149 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                    #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                    # =====================================================
                    grid_ts = 0
                    cur_t = 0
                    while cur_t < latents.shape[1]:
                        cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
                        grid_ts += 1

                # Calculate the current step percentage
                current_step_percentage = i / num_inference_steps
                    all_t = latents.shape[1]
                    latents_all_list = []
                    # =====================================================

                # Determine if control_latents should be applied
                apply_control = control_start_percent <= current_step_percentage <= control_end_percent
                current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
                    for t_i in range(grid_ts):
                        if t_i < grid_ts - 1:
                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                        if t_i == grid_ts - 1:
                            ofs_t = all_t - t_tile_length

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])
                        input_start_t = ofs_t
                        input_end_t = ofs_t + t_tile_length

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                    control_latents=current_control_latents,
                )[0]
                noise_pred = noise_pred.float()
                        image_rotary_emb = (
                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
                            if self.transformer.config.use_rotary_positional_embeddings
                            else None
                        )

                # perform guidance
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                        latents_tile = latents[:, input_start_t:input_end_t, :, :, :]
                        print("latents_tile.shape", latents_tile.shape)
                        control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
                        print("control_latents_tile.shape", control_latents_tile.shape)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                        latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                        latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)

                        #t_input = t[None].to(device)
                        t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

                        # predict noise model_output
                        noise_pred = self.transformer(
                            hidden_states=latent_model_input_tile,
                            encoder_hidden_states=prompt_embeds,
                            timestep=t_input,
                            image_rotary_emb=image_rotary_emb,
                            return_dict=False,
                            control_latents=control_latents_tile,
                        )[0]
                        noise_pred = noise_pred.float()

                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
                        latents_all_list.append(latents_tile)

                    # ==========================================
                    latents_all = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                    contributors = torch.zeros(latents.shape, device=latents.device, dtype=self.vae.dtype)
                    # Add each tile contribution to overall latents
                    for t_i in range(grid_ts):
                        if t_i < grid_ts - 1:
                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
                        if t_i == grid_ts - 1:
                            ofs_t = all_t - t_tile_length

                        input_start_t = ofs_t
                        input_end_t = ofs_t + t_tile_length
                        print("input_start_t, input_end_t", input_start_t, input_end_t, latents_all.shape)
                        print("t_tile_weights.shape", t_tile_weights.shape)

                        latents_all[:, input_start_t:input_end_t, :, :, :] += latents_all_list[t_i] * t_tile_weights
                        contributors[:, input_start_t:input_end_t, :, :, :] += t_tile_weights

                    latents_all /= contributors

                    latents = latents_all

                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        pbar.update(1)
                    # ==========================================
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # Calculate the current step percentage
                    current_step_percentage = i / num_inference_steps

                    # Determine if control_latents should be applied
                    apply_control = control_start_percent <= current_step_percentage <= control_end_percent
                    current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)

                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    timestep = t.expand(latent_model_input.shape[0])

                    # predict noise model_output
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        image_rotary_emb=image_rotary_emb,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)
                        control_latents=current_control_latents,
                    )[0]
                    noise_pred = noise_pred.float()

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    # compute the previous noisy sample x_t -> x_t-1
                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                    else:
                        latents, old_pred_original_sample = self.scheduler.step(
                            noise_pred,
                            old_pred_original_sample,
                            t,
                            timesteps[i - 1] if i > 0 else None,
                            latents,
                            **extra_step_kwargs,
                            return_dict=False,
                        )
                    latents = latents.to(prompt_embeds.dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if comfyui_progressbar:
                        pbar.update(1)
                    # call the callback, if provided
                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if comfyui_progressbar:
                            pbar.update(1)

        # if output_type == "numpy":
        #     video = self.decode_latents(latents)
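Inside the tiled branch, each denoising step splits the latent frames into overlapping tiles, denoises every tile with its own cropped rotary embeddings and control latents, and then merges the tiles back with the Gaussian weights, dividing by the per-frame sum of weights. The tile offsets follow the counting loop in the hunk above: step by `t_tile_length - t_tile_overlap` and anchor the last tile to the end of the sequence. A standalone sketch of just the offset scheme; the 22/12/2 frame counts are illustrative:

```python
def tile_offsets(total_t: int, tile_len: int, overlap: int) -> list[int]:
    # Same counting scheme as the hunk above: step by (tile_len - overlap) until the
    # sequence is covered, then anchor the final tile to the end so it never runs past it.
    grid_ts, cur_t = 0, 0
    while cur_t < total_t:
        cur_t = max(grid_ts * tile_len - overlap * grid_ts, 0) + tile_len
        grid_ts += 1
    offsets = []
    for t_i in range(grid_ts):
        ofs = max(t_i * tile_len - overlap * t_i, 0)
        if t_i == grid_ts - 1:
            ofs = total_t - tile_len
        offsets.append(ofs)
    return offsets

# Illustrative: 22 latent frames, 12-frame tiles, 2-frame overlap.
print(tile_offsets(22, 12, 2))  # [0, 10] -> tiles cover frames 0-11 and 10-21
```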

nodes.py
@@ -35,6 +35,7 @@ scheduler_mapping = {
    "Euler A": EulerAncestralDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
    "DDIM_tiled": CogVideoXDDIMScheduler,
    "CogVideoXDDIM": CogVideoXDDIMScheduler,
    "CogVideoXDPMScheduler": CogVideoXDPMScheduler,
    "SASolverScheduler": SASolverScheduler,
@@ -1244,6 +1245,7 @@ class CogVideoXFunControlSampler:
                    "DEISMultistepScheduler",
                    "CogVideoXDDIM",
                    "CogVideoXDPMScheduler",
                    "DDIM_tiled",
                ],
                {
                    "default": 'DDIM'
@@ -1252,6 +1254,9 @@ class CogVideoXFunControlSampler:
                "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),

            },
        }
@@ -1261,7 +1266,7 @@ class CogVideoXFunControlSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
                control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
                control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pipe = pipeline["pipe"]
@@ -1309,7 +1314,10 @@ class CogVideoXFunControlSampler:
            control_video=control_latents["latents"],
            control_strength=control_strength,
            control_start_percent=control_start_percent,
            control_end_percent=control_end_percent
            control_end_percent=control_end_percent,
            t_tile_length=t_tile_length,
            t_tile_overlap=t_tile_overlap,
            scheduler_name=scheduler
        )

        # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
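On the node side, the new "DDIM_tiled" entry maps to `CogVideoXDDIMScheduler`, and the sampler forwards the chosen scheduler name to the pipeline as `scheduler_name`, so temporal tiling is keyed off the substring "tiled" while `t_tile_length` / `t_tile_overlap` (node defaults 48 and 8 video frames) are passed straight through to the pipeline call. A tiny standalone illustration of that switch, not the node's own code:

```python
# Any scheduler whose name contains "tiled" (currently only "DDIM_tiled") switches the
# pipeline into the temporal-tiling branch; every other name keeps the original
# single-pass denoising loop.
for name in ["DDIM", "DDIM_tiled", "CogVideoXDDIM"]:
    temporal_tiling = "tiled" in name
    print(f"{name:15s} -> temporal tiling {'enabled' if temporal_tiling else 'disabled'}")
```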