From 00d38f9a22e8f40d0caae7e7c363717125268fa3 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 6 Oct 2024 01:28:20 +0300
Subject: [PATCH] Add context schedule and all samplers to the non-Fun sampler
 as well

---
 cogvideox_fun/pipeline_cogvideox_control.py |  10 +-
 nodes.py                                    |  87 +++++------
 pipeline_cogvideox.py                       | 156 ++++++++++++++++++--
 3 files changed, 179 insertions(+), 74 deletions(-)

diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index 93983a8..966e0ee 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -841,11 +841,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                     for c in context_queue:
                         partial_latent_model_input = latent_model_input[:, c, :, :, :]
                         partial_control_latents = current_control_latents[:, c, :, :, :]
-                        # image_rotary_emb = (
-                        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
-                        #     if self.transformer.config.use_rotary_positional_embeddings
-                        #     else None
-                        # )
 
                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
@@ -857,7 +852,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                             control_latents=partial_control_latents,
                         )[0]
 
-                        counter[:, c, :, :, :] += 1
+                        # uncond
                         if do_classifier_free_guidance:
                             noise_uncond[:, c, :, :, :] += self.transformer(
                                 hidden_states=partial_latent_model_input,
@@ -867,7 +862,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                                 return_dict=False,
                                 control_latents=partial_control_latents,
                             )[0]
-
+
+                        counter[:, c, :, :, :] += 1
                 noise_pred = noise_pred.float()
                 noise_pred /= counter
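
The hunks above move the `counter` increment out of the conditional branch so each context window is counted exactly once per step, after both the conditional and the optional unconditional pass, and the later division averages overlapping window predictions correctly. A toy sketch of that accumulate-and-average blending; shapes and names are illustrative, not the pipeline's real tensors:

    import torch

    frames = 13
    windows = [list(range(0, 6)), list(range(4, 10)), list(range(7, 13))]
    noise_pred = torch.zeros(1, frames, 4)
    counter = torch.zeros_like(noise_pred)
    for c in windows:
        # stand-in for the transformer call on one context window
        noise_pred[:, c] += torch.randn(1, len(c), 4)
        counter[:, c] += 1  # once per window, regardless of the CFG branch
    noise_pred /= counter   # frames covered by several windows get averaged
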
diff --git a/nodes.py b/nodes.py
index 24b8ce0..f232dba 100644
--- a/nodes.py
+++ b/nodes.py
@@ -27,6 +27,7 @@ from diffusers.schedulers import (
     HeunDiscreteScheduler,
     SASolverScheduler,
     DEISMultistepScheduler,
+    LCMScheduler
 )
 
 scheduler_mapping = {
@@ -41,8 +42,11 @@ scheduler_mapping = {
     "SASolverScheduler": SASolverScheduler,
     "UniPCMultistepScheduler": UniPCMultistepScheduler,
     "HeunDiscreteScheduler": HeunDiscreteScheduler,
-    "DEISMultistepScheduler": DEISMultistepScheduler
+    "DEISMultistepScheduler": DEISMultistepScheduler,
+    "LCMScheduler": LCMScheduler
 }
+available_schedulers = list(scheduler_mapping.keys())
+
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 
 from .pipeline_cogvideox import CogVideoXPipeline
@@ -833,14 +837,16 @@ class CogVideoSampler:
                 "steps": ("INT", {"default": 50, "min": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
-                "scheduler": (["DDIM", "DPM", "DDIM_tiled"], {"tooltip": "5B likes DPM, but it doesn't support temporal tiling"}),
-                "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiling, use same alue as num_frames to disable, disabled automatically for DPM"}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
+                "scheduler": (available_schedulers,
+                    {
+                        "default": 'CogVideoXDDIM'
+                    }),
             },
             "optional": {
                 "samples": ("LATENT", ),
                 "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "image_cond_latents": ("LATENT", ),
+                "context_options": ("COGCONTEXT", ),
             }
         }
 
@@ -849,17 +855,13 @@ class CogVideoSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, t_tile_length, t_tile_overlap, samples=None,
-                denoise_strength=1.0, image_cond_latents=None):
+    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
+                denoise_strength=1.0, image_cond_latents=None, context_options=None):
         mm.soft_empty_cache()
 
         base_path = pipeline["base_path"]
 
         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
-        assert t_tile_length > t_tile_overlap, "t_tile_length must be greater than t_tile_overlap"
-        assert t_tile_length <= num_frames, "t_tile_length must be equal or less than num_frames"
-        t_tile_length = t_tile_length // 4
-        t_tile_overlap = t_tile_overlap // 4
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -869,12 +871,20 @@ class CogVideoSampler:
         if not pipeline["cpu_offloading"]:
             pipe.transformer.to(device)
 
-        generator = torch.Generator(device=device).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
-        if scheduler == "DDIM" or scheduler == "DDIM_tiled":
-            pipe.scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
-        elif scheduler == "DPM":
-            pipe.scheduler = CogVideoXDPMScheduler.from_config(scheduler_config)
+        if scheduler in scheduler_mapping:
+            noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
+            pipe.scheduler = noise_scheduler
+        else:
+            raise ValueError(f"Unknown scheduler: {scheduler}")
+
+        if context_options is not None:
+            context_frames = context_options["context_frames"] // 4
+            context_stride = context_options["context_stride"] // 4
+            context_overlap = context_options["context_overlap"] // 4
+        else:
+            context_frames, context_stride, context_overlap = None, None, None
 
         if negative.shape[1] < positive.shape[1]:
             target_length = positive.shape[1]
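
The `// 4` divisions above convert the node's pixel-frame counts into latent frames: CogVideoX's causal VAE compresses time by a factor of 4 and keeps the first frame. A small sketch of that unit conversion; `to_latent_frames` and `convert_context_options` are hypothetical helpers, not part of this patch:

    def to_latent_frames(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
        # mirrors the latent shape math: 49 pixel frames -> (49 - 1) // 4 + 1 = 13
        return (num_frames - 1) // vae_scale_factor_temporal + 1

    def convert_context_options(context_options: dict) -> tuple:
        # The sampler floor-divides the user-facing context values by the
        # temporal compression factor before handing them to the pipeline.
        return (
            context_options["context_frames"] // 4,
            context_options["context_stride"] // 4,
            context_options["context_overlap"] // 4,
        )
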
{"default": 25, "min": 1, "max": 200, "step": 1}), "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}), - "scheduler": ( - [ - "Euler", - "Euler A", - "DPM++", - "PNDM", - "DDIM", - "SASolverScheduler", - "UniPCMultistepScheduler", - "HeunDiscreteScheduler", - "DEISMultistepScheduler", - "CogVideoXDDIM", - "CogVideoXDPMScheduler", - ], - { - "default": 'DDIM' - } - ), + "scheduler": (available_schedulers, {"default": 'DDIM'}), "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 57e2172..2abd11f 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -21,7 +21,6 @@ import torch.nn.functional as F import math from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor @@ -164,7 +163,8 @@ class CogVideoXPipeline(VideoSysPipeline): set_pab_manager(pab_config) def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None, + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, + num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None ): shape = ( batch_size, @@ -178,9 +178,43 @@ class CogVideoXPipeline(VideoSysPipeline): f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 57e2172..2abd11f 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -21,7 +21,6 @@ import torch.nn.functional as F
 import math
 
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
@@ -164,7 +163,8 @@ class CogVideoXPipeline(VideoSysPipeline):
             set_pab_manager(pab_config)
 
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
+        num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
     ):
         shape = (
             batch_size,
@@ -178,9 +178,43 @@ class CogVideoXPipeline(VideoSysPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+        noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=self.vae.dtype)
+        if freenoise and context_size is not None:
+            print("Applying FreeNoise")
+            # code and comments from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
+            video_length = num_frames // 4
+            delta = context_size - context_overlap
+            for start_idx in range(0, video_length - context_size, delta):
+                # start_idx corresponds to the beginning of a context window
+                # goal: place shuffled noise in the delta region right after the end of the context window
+                # if the space after the context window is not enough to place the noise, adjust and finish
+                place_idx = start_idx + context_size
+                # if place_idx is outside the valid indexes, we are already finished
+                if place_idx >= video_length:
+                    break
+                end_idx = place_idx - 1
+                #print("video_length:", video_length, "start_idx:", start_idx, "end_idx:", end_idx, "place_idx:", place_idx, "delta:", delta)
+
+                # if there is not enough room to copy delta amount of indexes, copy a limited amount and finish
+                if end_idx + delta >= video_length:
+                    final_delta = video_length - place_idx
+                    # generate list of indexes in final delta region
+                    list_idx = torch.tensor(list(range(start_idx, start_idx + final_delta)), device=torch.device("cpu"), dtype=torch.long)
+                    # shuffle list
+                    list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
+                    # apply shuffled indexes
+                    noise[:, place_idx:place_idx + final_delta, :, :, :] = noise[:, list_idx, :, :, :]
+                    break
+                # otherwise, do normal behavior
+                # generate list of indexes in delta region
+                list_idx = torch.tensor(list(range(start_idx, start_idx + delta)), device=torch.device("cpu"), dtype=torch.long)
+                # shuffle list
+                list_idx = list_idx[torch.randperm(delta, generator=generator)]
+                # apply shuffled indexes
+                #print("place_idx:", place_idx, "delta:", delta, "list_idx:", list_idx)
+                noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
         if latents is None:
-            latents = noise
+            latents = noise.to(device)
         else:
            latents = latents.to(device)
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
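
FreeNoise fills the frames that follow each context window with a shuffled copy of that window's noise, so overlapping windows denoise over correlated rather than independent noise. A condensed, standalone sketch of the same scheme; `freenoise_shuffle` is a hypothetical name, and the `span` clamp folds the patch's `final_delta` special case into one expression:

    import torch

    def freenoise_shuffle(noise, context_size, context_overlap, generator):
        video_length = noise.shape[1]  # [batch, frames, channels, height, width]
        delta = context_size - context_overlap
        for start_idx in range(0, video_length - context_size, delta):
            place_idx = start_idx + context_size
            if place_idx >= video_length:
                break
            # shrink the copied span if it would run past the last frame
            span = min(delta, video_length - place_idx)
            src = torch.arange(start_idx, start_idx + span)
            src = src[torch.randperm(span, generator=generator)]
            noise[:, place_idx:place_idx + span] = noise[:, src]
            if span < delta:
                break
        return noise

    noise = torch.randn(1, 13, 16, 60, 90)  # e.g. 13 latent frames
    noise = freenoise_shuffle(noise, context_size=12, context_overlap=4,
                              generator=torch.Generator().manual_seed(0))
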
@@ -346,6 +380,11 @@ class CogVideoXPipeline(VideoSysPipeline):
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         device = torch.device("cuda"),
         scheduler_name: str = "DPM",
+        context_schedule: Optional[str] = None,
+        context_frames: Optional[int] = None,
+        context_stride: Optional[int] = None,
+        context_overlap: Optional[int] = None,
+        freenoise: Optional[bool] = True,
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -448,7 +487,10 @@ class CogVideoXPipeline(VideoSysPipeline):
             timesteps,
             denoise_strength,
             num_inference_steps,
-            latents
+            latents,
+            context_size=context_frames,
+            context_overlap=context_overlap,
+            freenoise=freenoise,
         )
         latents = latents.to(self.vae.dtype)
         #print("latents", latents.shape)
@@ -492,22 +534,37 @@ class CogVideoXPipeline(VideoSysPipeline):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)
 
-        # 8. Temporal tiling prep
-        if "tiled" in scheduler_name:
+        # 8.5. Temporal tiling prep
+        if context_schedule == "temporal_tiling":
+            t_tile_length = context_frames
+            t_tile_overlap = context_overlap
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
-            temporal_tiling = True
+            use_temporal_tiling = True
+            use_context_schedule = False
             print("Temporal tiling enabled")
+        elif context_schedule is not None:
+            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            use_temporal_tiling = False
+            use_context_schedule = True
+            from .cogvideox_fun.context import get_context_scheduler
+            context = get_context_scheduler(context_schedule)
+
         else:
-            temporal_tiling = False
-            print("Temporal tiling disabled")
-        #print("latents.shape", latents.shape)
+            use_temporal_tiling = False
+            use_context_schedule = False
+            print("Temporal tiling and context schedule disabled")
+
+        # 9. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             old_pred_original_sample = None  # for DPM-solver++
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                if temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
+                if use_temporal_tiling and isinstance(self.scheduler, CogVideoXDDIMScheduler):
                     #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
                     # =====================================================
                     grid_ts = 0
@@ -533,7 +590,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                         #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                         image_rotary_emb = (
-                            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                            self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
                             if self.transformer.config.use_rotary_positional_embeddings
                             else None
                        )
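
`_gaussian_weights` (referenced above but not changed by this patch) supplies the per-frame blend weights for temporal tiles, weighting a tile's center frames more heavily than its edges to hide seams between tiles. A sketch of such a weighting, analogous to the helper rather than a copy of it:

    import math
    import torch

    def gaussian_tile_weights(t_tile_length: int) -> torch.Tensor:
        var = 0.01
        midpoint = (t_tile_length - 1) / 2  # symmetric around the tile center
        probs = [
            math.exp(-((t - midpoint) ** 2) / (t_tile_length ** 2) / (2 * var))
            for t in range(t_tile_length)
        ]
        # shape [t_tile_length]; broadcast over the latent dimensions when blending
        return torch.tensor(probs)
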
@@ -600,6 +657,79 @@ class CogVideoXPipeline(VideoSysPipeline):
                         progress_bar.update()
                         comfy_pbar.update(1)
                     # ==========================================
+                elif use_context_schedule:
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                    counter = torch.zeros_like(latent_model_input)
+                    noise_pred = torch.zeros_like(latent_model_input)
+
+                    if image_cond_latents is not None:
+                        latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
+                        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    context_queue = list(context(
+                        i, num_inference_steps, latents.shape[1], context_frames, context_stride, context_overlap,
+                    ))
+
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
+                    for c in context_queue:
+                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+
+                        # predict noise model_output; with classifier-free guidance the
+                        # batch already contains the unconditional half, so a single
+                        # forward pass per window is enough
+                        noise_pred[:, c, :, :, :] += self.transformer(
+                            hidden_states=partial_latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                        )[0]
+
+                        counter[:, c, :, :, :] += 1
+
+                    noise_pred = noise_pred.float()
+                    noise_pred /= counter
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        comfy_pbar.update(1)
+
                 else:
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
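
For reference, `get_context_scheduler` comes from `cogvideox_fun/context.py` and is not shown in this patch; it yields one list of latent-frame indices per window for each denoising step. An assumption-level sketch of what a uniform ("static") schedule amounts to, not the exact implementation:

    def static_context_windows(video_length: int, context_frames: int, context_overlap: int):
        # Windows advance by (context_frames - context_overlap); the last window
        # is clamped so it still spans context_frames frames ending at the end.
        stride = context_frames - context_overlap  # must be > 0
        windows, start = [], 0
        while True:
            end = min(start + context_frames, video_length)
            windows.append(list(range(max(end - context_frames, 0), end)))
            if end == video_length:
                break
            start += stride
        return windows

    print(static_context_windows(video_length=13, context_frames=6, context_overlap=2))
    # [[0, 1, 2, 3, 4, 5], [4, 5, 6, 7, 8, 9], [7, 8, 9, 10, 11, 12]]
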