diff --git a/cogvideox_fun/context.py b/cogvideox_fun/context.py
new file mode 100644
index 0000000..6a30fed
--- /dev/null
+++ b/cogvideox_fun/context.py
@@ -0,0 +1,184 @@
+import numpy as np
+from typing import Callable, Optional, List
+
+
+def ordered_halving(val):
+    # bit-reverse the binary representation of val to get a deterministic, well-spread offset in [0, 1)
+    bin_str = f"{val:064b}"
+    bin_flip = bin_str[::-1]
+    as_int = int(bin_flip, 2)
+
+    return as_int / (1 << 64)
+
+def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
+    prev_val = -1
+    for i, val in enumerate(window):
+        val = val % num_frames
+        if val < prev_val:
+            return True, i
+        prev_val = val
+    return False, -1
+
+def shift_window_to_start(window: list[int], num_frames: int):
+    start_val = window[0]
+    for i in range(len(window)):
+        # 1) subtract each element by start_val to move vals relative to the start of all frames
+        # 2) add num_frames and take modulus to get adjusted vals
+        window[i] = ((window[i] - start_val) + num_frames) % num_frames
+
+def shift_window_to_end(window: list[int], num_frames: int):
+    # 1) shift window to start
+    shift_window_to_start(window, num_frames)
+    end_val = window[-1]
+    end_delta = num_frames - end_val - 1
+    for i in range(len(window)):
+        # 2) add end_delta to each val to slide windows to end
+        window[i] = window[i] + end_delta
+
+def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
+    all_indexes = list(range(num_frames))
+    for w in windows:
+        for val in w:
+            try:
+                all_indexes.remove(val)
+            except ValueError:
+                pass
+    return all_indexes
+
+def uniform_looped(
+    step: int = ...,
+    num_steps: Optional[int] = None,
+    num_frames: int = ...,
+    context_size: Optional[int] = None,
+    context_stride: int = 3,
+    context_overlap: int = 4,
+    closed_loop: bool = True,
+):
+    if num_frames <= context_size:
+        yield list(range(num_frames))
+        return
+
+    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
+
+    for context_step in 1 << np.arange(context_stride):
+        pad = int(round(num_frames * ordered_halving(step)))
+        for j in range(
+            int(ordered_halving(step) * context_step) + pad,
+            num_frames + pad + (0 if closed_loop else -context_overlap),
+            (context_size * context_step - context_overlap),
+        ):
+            yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
+
+#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
+def uniform_standard(
+    step: int = ...,
+    num_steps: Optional[int] = None,
+    num_frames: int = ...,
+    context_size: Optional[int] = None,
+    context_stride: int = 3,
+    context_overlap: int = 4,
+    closed_loop: bool = True,
+):
+    windows = []
+    if num_frames <= context_size:
+        windows.append(list(range(num_frames)))
+        return windows
+
+    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
+
+    for context_step in 1 << np.arange(context_stride):
+        pad = int(round(num_frames * ordered_halving(step)))
+        for j in range(
+            int(ordered_halving(step) * context_step) + pad,
+            num_frames + pad + (0 if closed_loop else -context_overlap),
+            (context_size * context_step - context_overlap),
+        ):
+            windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)])
+
+    # now that windows are created, shift any windows that loop, and delete duplicate windows
+    delete_idxs = []
+    win_i = 0
+    while win_i < len(windows):
+        # if the window rolls over itself, it needs to be shifted
+        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
+        if is_roll:
+            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
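+            # shifting moves the window so it ends on the last frame instead of wrapping
+            # past it, e.g. [14, 15, 0, 1] with num_frames=16 becomes [12, 13, 14, 15]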
+            shift_window_to_end(windows[win_i], num_frames=num_frames)
+            # check if next window (cyclical) is missing roll_val
+            if roll_val not in windows[(win_i+1) % len(windows)]:
+                # need to insert new window here - just insert window starting at roll_val
+                windows.insert(win_i+1, list(range(roll_val, roll_val + context_size)))
+        # delete window if it's not unique
+        for pre_i in range(0, win_i):
+            if windows[win_i] == windows[pre_i]:
+                delete_idxs.append(win_i)
+                break
+        win_i += 1
+
+    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
+    delete_idxs.reverse()
+    for i in delete_idxs:
+        windows.pop(i)
+    return windows
+
+def static_standard(
+    step: int = ...,
+    num_steps: Optional[int] = None,
+    num_frames: int = ...,
+    context_size: Optional[int] = None,
+    context_stride: int = 3,
+    context_overlap: int = 4,
+    closed_loop: bool = True,
+):
+    windows = []
+    if num_frames <= context_size:
+        windows.append(list(range(num_frames)))
+        return windows
+    # always return the same set of windows
+    delta = context_size - context_overlap
+    for start_idx in range(0, num_frames, delta):
+        # if past the end of frames, move start_idx back to allow same context_length
+        ending = start_idx + context_size
+        if ending >= num_frames:
+            final_delta = ending - num_frames
+            final_start_idx = start_idx - final_delta
+            windows.append(list(range(final_start_idx, final_start_idx + context_size)))
+            break
+        windows.append(list(range(start_idx, start_idx + context_size)))
+    return windows
+
+def get_context_scheduler(name: str) -> Callable:
+    if name == "uniform_looped":
+        return uniform_looped
+    elif name == "uniform_standard":
+        return uniform_standard
+    elif name == "static_standard":
+        return static_standard
+    else:
+        raise ValueError(f"Unknown context_overlap policy {name}")
+
+
+def get_total_steps(
+    scheduler,
+    timesteps: List[int],
+    num_steps: Optional[int] = None,
+    num_frames: int = ...,
+    context_size: Optional[int] = None,
+    context_stride: int = 3,
+    context_overlap: int = 4,
+    closed_loop: bool = True,
+):
+    return sum(
+        len(
+            list(
+                scheduler(
+                    i,
+                    num_steps,
+                    num_frames,
+                    context_size,
+                    context_stride,
+                    context_overlap,
+                )
+            )
+        )
+        for i in range(len(timesteps))
+    )
diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 9edc283..cad1fad 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -395,8 +395,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
-        start_frame: int = None,
-        end_frame: int = None,
+        start_frame: Optional[int] = None,
+        end_frame: Optional[int] = None,
+        context_frames: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -414,12 +415,15 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             use_real=True,
         )

-        if start_frame is not None:
+        if start_frame is not None or context_frames is not None:
             freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
             freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
-
-            freqs_cos = freqs_cos[start_frame:end_frame]
-            freqs_sin = freqs_sin[start_frame:end_frame]
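+            # context_frames, when given, is a list of latent-frame indices (one context
+            # window), so gather the rotary frequencies for exactly those frames rather
+            # than a contiguous start:end slice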
+            if context_frames is not None:
+                freqs_cos = freqs_cos[context_frames]
+                freqs_sin = freqs_sin[context_frames]
+            else:
+                freqs_cos = freqs_cos[start_frame:end_frame]
+                freqs_sin = freqs_sin[start_frame:end_frame]

             freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
             freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
@@ -483,9 +487,11 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         control_strength: float = 1.0,
         control_start_percent: float = 0.0,
         control_end_percent: float = 1.0,
-        t_tile_length: int = 12,
-        t_tile_overlap: int = 4,
         scheduler_name: str = "DPM",
+        context_schedule: Optional[str] = None,
+        context_frames: Optional[int] = None,
+        context_stride: Optional[int] = None,
+        context_overlap: Optional[int] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -652,23 +658,33 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         # 8.5. Temporal tiling prep
-        if "tiled" in scheduler_name:
-            t_tile_length = t_tile_length // 4
-            t_tile_overlap = t_tile_overlap // 4
+        if context_schedule is not None and context_schedule == "temporal_tiling":
+            t_tile_length = context_frames // 4
+            t_tile_overlap = context_overlap // 4
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
-            temporal_tiling = True
-
+            use_temporal_tiling = True
             print("Temporal tiling enabled")
+        elif context_schedule is not None:
+            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            use_temporal_tiling = False
+            use_context_schedule = True
+            context_frames = context_frames // 4
+            context_stride = context_stride // 4
+            context_overlap = context_overlap // 4
+            from .context import get_context_scheduler
+            context = get_context_scheduler(context_schedule)
+
         else:
-            temporal_tiling = False
-            print("Temporal tiling disabled")
+            use_temporal_tiling = False
+            use_context_schedule = False
+            print("Temporal tiling and context schedule disabled")
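+        # NOTE: context_frames / context_stride / context_overlap are given in pixel
+        # frames; they are divided by 4 above because the CogVideoX VAE compresses four
+        # pixel frames into one latent frame, so windows are built over latent frames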

         # 7. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-        #print("latents.shape", latents.shape)
+

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
@@ -677,7 +693,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 if self.interrupt:
                     continue

-                if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
+                if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                    #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                    # =====================================================
                    grid_ts = 0
@@ -757,6 +773,96 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                        progress_bar.update()
                            pbar.update(1)
                    # ==========================================
+                elif use_context_schedule:
+
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # Calculate the current step percentage
+                    current_step_percentage = i / num_inference_steps
+
+                    # Determine if control_latents should be applied
+                    apply_control = control_start_percent <= current_step_percentage <= control_end_percent
+                    current_control_latents = control_latents if apply_control else torch.zeros_like(control_latents)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+                    counter = torch.zeros_like(latent_model_input)
+                    noise_pred = torch.zeros_like(latent_model_input)
+                    if do_classifier_free_guidance:
+                        noise_uncond = torch.zeros_like(latent_model_input)
+
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        partial_control_latents = current_control_latents[:, c, :, :, :]
+                        image_rotary_emb = (
+                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                            if self.transformer.config.use_rotary_positional_embeddings
+                            else None
+                        )
+
+                        # predict noise model_output
+                        noise_pred[:, c, :, :, :] += self.transformer(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                            control_latents=partial_control_latents,
+                        )[0]
+
+                        counter[:, c, :, :, :] += 1
+                        if do_classifier_free_guidance:
+                            noise_uncond[:, c, :, :, :] += self.transformer(
+                                hidden_states=partial_latent_model_input,
+                                encoder_hidden_states=prompt_embeds,
+                                timestep=timestep,
+                                image_rotary_emb=image_rotary_emb,
+                                return_dict=False,
+                                control_latents=partial_control_latents,
+                            )[0]
+
+                    noise_pred = noise_pred.float()
+
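+                    # frames covered by more than one context window have had their
+                    # predictions summed above, so divide by the per-frame counter to
+                    # average the overlapping window outputs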
+                    noise_pred /= counter
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    # call the callback, if provided
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if comfyui_progressbar:
+                            pbar.update(1)
                 else:
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
diff --git a/examples/cogvideox_fun_pose_example_01.json b/examples/cogvideox_fun_pose_example_01.json
index 9933267..da8750d 100644
--- a/examples/cogvideox_fun_pose_example_01.json
+++ b/examples/cogvideox_fun_pose_example_01.json
@@ -682,7 +682,7 @@
       },
       "size": {
         "0": 367.79998779296875,
-        "1": 102
+        "1": 146
       },
       "flags": {},
       "order": 10,
@@ -706,8 +706,20 @@
           "links": [
             182
           ],
-          "shape": 3,
-          "slot_index": 0
+          "slot_index": 0,
+          "shape": 3
+        },
+        {
+          "name": "width",
+          "type": "INT",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "height",
+          "type": "INT",
+          "links": null,
+          "shape": 3
         }
       ],
       "properties": {
@@ -715,7 +727,8 @@
       },
       "widgets_values": [
         512,
-        false
+        false,
+        0
       ]
     },
     {
@@ -725,10 +738,10 @@
         "0": 1085,
         "1": 312
       },
-      "size": [
-        311.22059416191496,
-        286
-      ],
+      "size": {
+        "0": 311.2205810546875,
+        "1": 350
+      },
       "flags": {},
       "order": 11,
       "mode": 0,
@@ -752,6 +765,16 @@
           "name": "control_latents",
           "type": "COGCONTROL_LATENTS",
           "link": 182
+        },
+        {
+          "name": "samples",
+          "type": "LATENT",
+          "link": null
+        },
+        {
+          "name": "context_options",
+          "type": "COGCONTEXT",
+          "link": null
         }
       ],
       "outputs": [
@@ -783,6 +806,7 @@
         "CogVideoXDPMScheduler",
         0.7000000000000001,
         0,
+        1,
         1
       ]
     },
@@ -1034,10 +1058,10 @@
   "config": {},
   "extra": {
     "ds": {
-      "scale": 0.6934334949441758,
+      "scale": 0.6934334949442492,
       "offset": [
-        364.1021432696588,
-        27.28260472943026
+        39.55130702561554,
+        104.54407751572876
       ]
     }
   },
diff --git a/nodes.py b/nodes.py
index bc3d188..480f54b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -55,7 +55,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
 from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
 from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
 from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
-from .cogvideox_fun.lora_utils import merge_lora, unmerge_lora
+from .cogvideox_fun.lora_utils import merge_lora
 from PIL import Image
 import numpy as np
 import json
@@ -342,12 +342,12 @@ class DownloadAndLoadCogVideoModel:
         transformer = transformer.to(dtype).to(offload_device)

         if lora is not None:
-            if lora['strength'] > 0:
-                logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+            logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+            if "fun" in model.lower():
                 transformer = merge_lora(transformer, lora["path"], lora["strength"])
             else:
-                logging.info(f"Removing LoRA weights from {lora['path']} with strength {lora['strength']}")
-                transformer = unmerge_lora(transformer, lora["path"], lora["strength"])
+                raise NotImplementedError("LoRA merging is currently only supported for Fun models")
+

         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
@@ -381,9 +381,7 @@ class DownloadAndLoadCogVideoModel:
             pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config)
         else:
             vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
-            pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
-
-
+            pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)

         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
@@ -1274,7 +1272,34 @@ class CogVideoControlImageEncode:
         }
         return (control_latents, width, height)
-
+
+
+class CogVideoContextOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "context_schedule": (["uniform_standard", "uniform_looped", "static_standard", "temporal_tiling"],),
+            "context_frames": ("INT", {"default": 12, "min": 2, "max": 100, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ),
+            "context_stride": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
+            "context_overlap": ("INT", {"default": 4, "min": 4, "max": 10, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ),
+            }
+        }
+
+    RETURN_TYPES = ("COGCONTEXT", )
+    RETURN_NAMES = ("context_options",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoWrapper"
+
+    def process(self, context_schedule, context_frames, context_stride, context_overlap):
+        context_options = {
+            "context_schedule": context_schedule,
+            "context_frames": context_frames,
+            "context_stride": context_stride,
+            "context_overlap": context_overlap
+        }
+
+        return (context_options,)
+
 class CogVideoXFunControlSampler:
     @classmethod
     def INPUT_TYPES(s):
@@ -1300,7 +1325,6 @@ class CogVideoXFunControlSampler:
                        "DEISMultistepScheduler",
                        "CogVideoXDDIM",
                        "CogVideoXDPMScheduler",
-                        "DDIM_tiled",
                    ],
                    {
                        "default": 'DDIM'
@@ -1309,12 +1333,11 @@ class CogVideoXFunControlSampler:
                "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                "t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
            },
            "optional": {
                "samples": ("LATENT", ),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "context_options": ("COGCONTEXT", ),
            },
        }
@@ -1325,7 +1348,7 @@ class CogVideoXFunControlSampler:
     def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
                 control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
-                samples=None, denoise_strength=1.0):
+                samples=None, denoise_strength=1.0, context_options=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1341,6 +1364,9 @@ class CogVideoXFunControlSampler:

         # Load Sampler
         scheduler_config = pipeline["scheduler_config"]
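+        # the temporal tiling path in the control pipeline is only implemented for the
+        # DDIM-based tiled scheduler, so the selected scheduler is overridden here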
pipeline["scheduler_config"] + if context_options is not None and context_options["context_schedule"] == "temporal_tiling": + logging.info("Temporal tiling enabled, changing scheduler to DDIM_tiled") + scheduler="DDIM_tiled" if scheduler in scheduler_mapping: noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) pipe.scheduler = noise_scheduler @@ -1371,11 +1397,14 @@ class CogVideoXFunControlSampler: control_strength=control_strength, control_start_percent=control_start_percent, control_end_percent=control_end_percent, - t_tile_length=t_tile_length, - t_tile_overlap=t_tile_overlap, scheduler_name=scheduler, latents=samples["samples"] if samples is not None else None, denoise_strength=denoise_strength, + context_schedule=context_options["context_schedule"] if context_options is not None else None, + context_frames=context_options["context_frames"] if context_options is not None else None, + context_stride=context_options["context_stride"] if context_options is not None else None, + context_overlap=context_options["context_overlap"] if context_options is not None else None + ) return (pipeline, {"samples": latents}) @@ -1395,7 +1424,8 @@ NODE_CLASS_MAPPINGS = { "CogVideoPABConfig": CogVideoPABConfig, "CogVideoTransformerEdit": CogVideoTransformerEdit, "CogVideoControlImageEncode": CogVideoControlImageEncode, - "CogVideoLoraSelect": CogVideoLoraSelect + "CogVideoLoraSelect": CogVideoLoraSelect, + "CogVideoContextOptions": CogVideoContextOptions } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -1412,5 +1442,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoPABConfig": "CogVideo PABConfig", "CogVideoTransformerEdit": "CogVideo TransformerEdit", "CogVideoControlImageEncode": "CogVideo Control ImageEncode", - "CogVideoLoraSelect": "CogVideo LoraSelect" + "CogVideoLoraSelect": "CogVideo LoraSelect", + "CogVideoContextOptions": "CogVideo Context Options" }