From 7af6666c67f855e2e1533c7fbc3091e97b7fbc11 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 4 Oct 2024 13:14:54 +0300
Subject: [PATCH] latent input for the control sampler, allows vid2vid along
 with pose input

---
 cogvideox_fun/pipeline_cogvideox_control.py | 31 ++++++++++++++-----
 nodes.py                                    | 34 +++++++++++++--------
 2 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py
index e5578ad..9edc283 100644
--- a/cogvideox_fun/pipeline_cogvideox_control.py
+++ b/cogvideox_fun/pipeline_cogvideox_control.py
@@ -214,7 +214,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         set_pab_manager(pab_config)
 
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
     ):
         shape = (
             batch_size,
@@ -228,15 +228,28 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = noise
         else:
             latents = latents.to(device)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
+            latent_timestep = timesteps[:1]
+
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+            frames_needed = noise.shape[1]
+            current_frames = latents.shape[1]
+
+            if frames_needed > current_frames:
+                repeat_factor = frames_needed // current_frames
+                additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
+                latents = torch.cat((latents, additional_frame), dim=1)
+            elif frames_needed < current_frames:
+                latents = latents[:, :frames_needed, :, :, :]
 
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-        return latents
+            latents = self.scheduler.add_noise(latents, noise, latent_timestep)
+        latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
+        return latents, timesteps, noise
 
     def prepare_control_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
@@ -452,6 +465,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
         use_dynamic_cfg: bool = False,
+        denoise_strength: float = 1.0,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -601,7 +615,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
 
         # 5. Prepare latents.
         latent_channels = self.vae.config.latent_channels
-        latents = self.prepare_latents(
+        latents, timesteps, noise = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -610,6 +624,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             self.vae.dtype,
             device,
             generator,
+            timesteps,
+            denoise_strength,
+            num_inference_steps,
             latents,
         )
         if comfyui_progressbar:

diff --git a/nodes.py b/nodes.py
index 24897a5..36a50e8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -746,7 +746,7 @@ class CogVideoImageEncode:
             },
             "optional": {
                 "chunk_size": ("INT", {"default": 16, "min": 1}),
-                "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
+                "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
             },
         }
@@ -756,7 +756,7 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
+    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
@@ -764,14 +764,16 @@ class CogVideoImageEncode:
         B, H, W, C = image.shape
 
         vae = pipeline["pipe"].vae
+        vae.enable_slicing()
 
-        if enable_vae_slicing:
-            vae.enable_slicing()
-        else:
-            vae.disable_slicing()
+        if enable_tiling:
+            from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
+            enable_vae_encode_tiling(vae)
 
         if not pipeline["cpu_offloading"]:
             vae.to(device)
+
+        vae._clear_fake_context_parallel_cache()
 
         input_image = image.clone()
         if mask is not None:
@@ -1211,8 +1213,8 @@ class CogVideoControlImageEncode:
             },
         }
 
-    RETURN_TYPES = ("COGCONTROL_LATENTS",)
-    RETURN_NAMES = ("control_latents",)
+    RETURN_TYPES = ("COGCONTROL_LATENTS", "INT", "INT",)
+    RETURN_NAMES = ("control_latents", "width", "height")
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
@@ -1271,7 +1273,7 @@ class CogVideoControlImageEncode:
             "width" : width,
         }
 
-        return (control_latents, )
+        return (control_latents, width, height)
 
 class CogVideoXFunControlSampler:
     @classmethod
@@ -1309,7 +1311,10 @@ class CogVideoXFunControlSampler:
                 "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
                 "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
-
+            },
+            "optional": {
+                "samples": ("LATENT", ),
+                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }
 
@@ -1318,8 +1323,9 @@ class CogVideoXFunControlSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
-        control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,):
+    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
+        control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
+        samples=None, denoise_strength=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1367,7 +1373,9 @@ class CogVideoXFunControlSampler:
             control_end_percent=control_end_percent,
             t_tile_length=t_tile_length,
             t_tile_overlap=t_tile_overlap,
-            scheduler_name=scheduler
+            scheduler_name=scheduler,
+            latents=samples["samples"] if samples is not None else None,
+            denoise_strength=denoise_strength,
         )
 
         return (pipeline, {"samples": latents})
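
Notes on the vid2vid path:

The new code in prepare_latents hinges on self.get_timesteps(num_inference_steps, denoise_strength, device), whose definition is not part of this patch. In diffusers-style img2img/vid2vid pipelines this helper truncates the schedule so that denoising starts partway in; the sketch below shows that conventional form, assuming this pipeline follows it:

    # Sketch of the strength-based schedule truncation assumed by
    # prepare_latents() above. Mirrors the get_timesteps helper of
    # diffusers img2img pipelines; NOT part of the diff itself.
    def get_timesteps(self, num_inference_steps, strength, device):
        # Number of denoising steps that will actually run.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        # Skip the earliest (noisiest) portion of the schedule.
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        return timesteps, num_inference_steps - t_start

For example, with steps=50 and denoise_strength=0.6 only the last 30 timesteps are kept; latent_timestep = timesteps[:1] is then the noise level to which scheduler.add_noise pushes the input video latents, so lower strengths preserve more of the source video.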
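At the node level, vid2vid runs through the two new optional inputs of CogVideoXFunControlSampler: feed encoded source-video latents into samples and lower denoise_strength, while the pose sequence stays on control_latents. A hypothetical call sketch follows; the encoded_video binding and the seed/steps/cfg/scheduler values are illustrative, not from the patch:

    # Hypothetical wiring, matching the process() signature added above.
    # `encoded_video` stands in for a {"samples": tensor} dict produced by a
    # latent-encode node; in practice ComfyUI wires this through the graph.
    sampler = CogVideoXFunControlSampler()
    pipeline, out_latents = sampler.process(
        pipeline, positive, negative,
        seed=42, steps=50, cfg=6.0, scheduler="DDIM",
        control_latents=control_latents,  # e.g. pose video from CogVideoControlImageEncode
        samples=encoded_video,            # new input: source-video latents
        denoise_strength=0.7,             # new input: <1.0 keeps source structure
    )
    # out_latents is {"samples": latents}, ready for a decode node.

With no samples connected, latents stays None and prepare_latents falls back to pure noise, reproducing the previous control-to-video behavior.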