Add context schedules for the control pipeline

This commit is contained in:
kijai 2024-10-05 16:15:34 +03:00
parent 033ec61d86
commit 453ee1606e
4 changed files with 390 additions and 45 deletions

184
cogvideox_fun/context.py Normal file
View File

@ -0,0 +1,184 @@
import numpy as np
from typing import Callable, Optional, List
def ordered_halving(val):
bin_str = f"{val:064b}"
bin_flip = bin_str[::-1]
as_int = int(bin_flip, 2)
return as_int / (1 << 64)
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
prev_val = -1
for i, val in enumerate(window):
val = val % num_frames
if val < prev_val:
return True, i
prev_val = val
return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
start_val = window[0]
for i in range(len(window)):
# 1) subtract each element by start_val to move vals relative to the start of all frames
# 2) add num_frames and take modulus to get adjusted vals
window[i] = ((window[i] - start_val) + num_frames) % num_frames
def shift_window_to_end(window: list[int], num_frames: int):
# 1) shift window to start
shift_window_to_start(window, num_frames)
end_val = window[-1]
end_delta = num_frames - end_val - 1
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
all_indexes = list(range(num_frames))
for w in windows:
for val in w:
try:
all_indexes.remove(val)
except ValueError:
pass
return all_indexes
def uniform_looped(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
if num_frames <= context_size:
yield list(range(num_frames))
return
context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
windows = []
if num_frames <= context_size:
windows.append(list(range(num_frames)))
return windows
context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)])
# now that windows are created, shift any windows that loop, and delete duplicate windows
delete_idxs = []
win_i = 0
while win_i < len(windows):
# if window is rolls over itself, need to shift it
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
if is_roll:
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
shift_window_to_end(windows[win_i], num_frames=num_frames)
# check if next window (cyclical) is missing roll_val
if roll_val not in windows[(win_i+1) % len(windows)]:
# need to insert new window here - just insert window starting at roll_val
windows.insert(win_i+1, list(range(roll_val, roll_val + context_size)))
# delete window if it's not unique
for pre_i in range(0, win_i):
if windows[win_i] == windows[pre_i]:
delete_idxs.append(win_i)
break
win_i += 1
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
delete_idxs.reverse()
for i in delete_idxs:
windows.pop(i)
return windows
def static_standard(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
windows = []
if num_frames <= context_size:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows
delta = context_size - context_overlap
for start_idx in range(0, num_frames, delta):
# if past the end of frames, move start_idx back to allow same context_length
ending = start_idx + context_size
if ending >= num_frames:
final_delta = ending - num_frames
final_start_idx = start_idx - final_delta
windows.append(list(range(final_start_idx, final_start_idx + context_size)))
break
windows.append(list(range(start_idx, start_idx + context_size)))
return windows
def get_context_scheduler(name: str) -> Callable:
if name == "uniform_looped":
return uniform_looped
elif name == "uniform_standard":
return uniform_standard
elif name == "static_standard":
return static_standard
else:
raise ValueError(f"Unknown context_overlap policy {name}")
def get_total_steps(
scheduler,
timesteps: List[int],
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
return sum(
len(
list(
scheduler(
i,
num_steps,
num_frames,
context_size,
context_stride,
context_overlap,
)
)
)
for i in range(len(timesteps))
)

View File

@ -395,8 +395,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
width: int,
num_frames: int,
device: torch.device,
start_frame: int = None,
end_frame: int = None,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
context_frames: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@ -414,12 +415,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
use_real=True,
)
if start_frame is not None:
if start_frame is not None or context_frames is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
if context_frames is not None:
freqs_cos = freqs_cos[context_frames]
freqs_sin = freqs_sin[context_frames]
else:
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
@ -483,9 +487,11 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
control_strength: float = 1.0,
control_start_percent: float = 0.0,
control_end_percent: float = 1.0,
t_tile_length: int = 12,
t_tile_overlap: int = 4,
scheduler_name: str = "DPM",
context_schedule: Optional[str] = None,
context_frames: Optional[int] = None,
context_stride: Optional[int] = None,
context_overlap: Optional[int] = None,
) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@ -652,23 +658,33 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.5. Temporal tiling prep
if "tiled" in scheduler_name:
t_tile_length = t_tile_length // 4
t_tile_overlap = t_tile_overlap // 4
if context_schedule is not None and context_schedule == "temporal_tiling":
t_tile_length = context_frames // 4
t_tile_overlap = context_overlap // 4
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
temporal_tiling = True
use_temporal_tiling = True
print("Temporal tiling enabled")
elif context_schedule is not None:
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
context_frames = context_frames // 4
context_stride = context_stride // 4
context_overlap = context_overlap // 4
from .context import get_context_scheduler
context = get_context_scheduler(context_schedule)
else:
temporal_tiling = False
print("Temporal tiling disabled")
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
#print("latents.shape", latents.shape)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
@ -677,7 +693,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if self.interrupt:
continue
if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
@ -757,6 +773,96 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
progress_bar.update()
pbar.update(1)
# ==========================================
elif use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Calculate the current step percentage
current_step_percentage = i / num_inference_steps
# Determine if control_latents should be applied
apply_control = control_start_percent <= current_step_percentage <= control_end_percent
current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
context_queue = list(context(
i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
))
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
if do_classifier_free_guidance:
noise_uncond = torch.zeros_like(latent_model_input)
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
partial_control_latents = current_control_latents[:, c, :, :, :]
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=partial_control_latents,
)[0]
counter[:, c, :, :, :] += 1
if do_classifier_free_guidance:
noise_uncond[:, c, :, :, :] += self.transformer(
hidden_states=partial_latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
control_latents=partial_control_latents,
)[0]
noise_pred = noise_pred.float()
noise_pred /= counter
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
pbar.update(1)
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

View File

@ -682,7 +682,7 @@
},
"size": {
"0": 367.79998779296875,
"1": 102
"1": 146
},
"flags": {},
"order": 10,
@ -706,8 +706,20 @@
"links": [
182
],
"shape": 3,
"slot_index": 0
"slot_index": 0,
"shape": 3
},
{
"name": "width",
"type": "INT",
"links": null,
"shape": 3
},
{
"name": "height",
"type": "INT",
"links": null,
"shape": 3
}
],
"properties": {
@ -715,7 +727,8 @@
},
"widgets_values": [
512,
false
false,
0
]
},
{
@ -725,10 +738,10 @@
"0": 1085,
"1": 312
},
"size": [
311.22059416191496,
286
],
"size": {
"0": 311.2205810546875,
"1": 350
},
"flags": {},
"order": 11,
"mode": 0,
@ -752,6 +765,16 @@
"name": "control_latents",
"type": "COGCONTROL_LATENTS",
"link": 182
},
{
"name": "samples",
"type": "LATENT",
"link": null
},
{
"name": "context_options",
"type": "COGCONTEXT",
"link": null
}
],
"outputs": [
@ -783,6 +806,7 @@
"CogVideoXDPMScheduler",
0.7000000000000001,
0,
1,
1
]
},
@ -1034,10 +1058,10 @@
"config": {},
"extra": {
"ds": {
"scale": 0.6934334949441758,
"scale": 0.6934334949442492,
"offset": [
364.1021432696588,
27.28260472943026
39.55130702561554,
104.54407751572876
]
}
},

View File

@ -55,7 +55,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
from .cogvideox_fun.lora_utils import merge_lora, unmerge_lora
from .cogvideox_fun.lora_utils import merge_lora
from PIL import Image
import numpy as np
import json
@ -342,12 +342,12 @@ class DownloadAndLoadCogVideoModel:
transformer = transformer.to(dtype).to(offload_device)
if lora is not None:
if lora['strength'] > 0:
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
if "fun" in model.lower():
transformer = merge_lora(transformer, lora["path"], lora["strength"])
else:
logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
raise NotImplementedError("LoRA merging is currently only supported for Fun models")
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@ -381,9 +381,7 @@ class DownloadAndLoadCogVideoModel:
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
@ -1274,7 +1272,34 @@ class CogVideoControlImageEncode:
}
return (control_latents, width, height)
class CogVideoContextOptions:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
"context_frames": ("INT", {"default": 12, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
"context_stride": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
"context_overlap": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
}
}
RETURN_TYPES = ("COGCONTEXT", )
RETURN_NAMES = ("context_options",)
FUNCTION = "process"
CATEGORY = "CogVideoWrapper"
def process(self, context_schedule, context_frames, context_stride, context_overlap):
context_options = {
"context_schedule":context_schedule,
"context_frames":context_frames,
"context_stride":context_stride,
"context_overlap":context_overlap
}
return (context_options,)
class CogVideoXFunControlSampler:
@classmethod
def INPUT_TYPES(s):
@ -1300,7 +1325,6 @@ class CogVideoXFunControlSampler:
"DEISMultistepScheduler",
"CogVideoXDDIM",
"CogVideoXDPMScheduler",
"DDIM_tiled",
],
{
"default": 'DDIM'
@ -1309,12 +1333,11 @@ class CogVideoXFunControlSampler:
"control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
"t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
},
"optional": {
"samples": ("LATENT", ),
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"context_options": ("COGCONTEXT", ),
},
}
@ -1325,7 +1348,7 @@ class CogVideoXFunControlSampler:
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
samples=None, denoise_strength=1.0):
samples=None, denoise_strength=1.0, context_options=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
@ -1341,6 +1364,9 @@ class CogVideoXFunControlSampler:
# Load Sampler
scheduler_config = pipeline["scheduler_config"]
if context_options is not None and context_options["context_schedule"] == "temporal_tiling":
logging.info("Temporal tiling enabled, changing scheduler to DDIM_tiled")
scheduler="DDIM_tiled"
if scheduler in scheduler_mapping:
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
pipe.scheduler = noise_scheduler
@ -1371,11 +1397,14 @@ class CogVideoXFunControlSampler:
control_strength=control_strength,
control_start_percent=control_start_percent,
control_end_percent=control_end_percent,
t_tile_length=t_tile_length,
t_tile_overlap=t_tile_overlap,
scheduler_name=scheduler,
latents=samples["samples"] if samples is not None else None,
denoise_strength=denoise_strength,
context_schedule=context_options["context_schedule"] if context_options is not None else None,
context_frames=context_options["context_frames"] if context_options is not None else None,
context_stride=context_options["context_stride"] if context_options is not None else None,
context_overlap=context_options["context_overlap"] if context_options is not None else None
)
return (pipeline, {"samples": latents})
@ -1395,7 +1424,8 @@ NODE_CLASS_MAPPINGS = {
"CogVideoPABConfig": CogVideoPABConfig,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoLoraSelect": CogVideoLoraSelect
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoContextOptions": CogVideoContextOptions
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -1412,5 +1442,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoPABConfig": "CogVideo PABConfig",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoLoraSelect": "CogVideo LoraSelect"
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoContextOptions": "CogVideo Context Options"
}