From 627af9341c4760878bc9399c5f9d0706353974d2 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:08:11 +0300 Subject: [PATCH] separate control image encoder and add tiled encode support for it tiled encoding code thanks to MinusZoneML: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/main/mz_enable_vae_encode_tiling.py --- cogvideox_fun/pipeline_cogvideox_control.py | 41 +++-- mz_enable_vae_encode_tiling.py | 188 ++++++++++++++++++++ nodes.py | 110 +++++++++--- 3 files changed, 294 insertions(+), 45 deletions(-) create mode 100644 mz_enable_vae_encode_tiling.py diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py index 5dcda4a..2a204cf 100644 --- a/cogvideox_fun/pipeline_cogvideox_control.py +++ b/cogvideox_fun/pipeline_cogvideox_control.py @@ -590,26 +590,29 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): if comfyui_progressbar: pbar.update(1) - if control_video is not None: - video_length = control_video.shape[2] - control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) - control_video = control_video.to(dtype=torch.float32) - control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) - else: - control_video = None - control_video_latents = self.prepare_control_latents( - None, - control_video, - batch_size, - height, - width, - self.vae.dtype, - device, - generator, - do_classifier_free_guidance - )[1] + # if control_video is not None: + # video_length = control_video.shape[2] + # control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + # control_video = control_video.to(dtype=torch.float32) + # control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + # else: + # control_video = None + + # control_video_latents = self.prepare_control_latents( + # None, + # control_video, + # batch_size, + # height, + # width, + # self.vae.dtype, + # device, + # generator, + # do_classifier_free_guidance + # )[1] + + control_video_latents_input = ( - torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video ) control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w") diff --git a/mz_enable_vae_encode_tiling.py b/mz_enable_vae_encode_tiling.py new file mode 100644 index 0000000..90b1d7d --- /dev/null +++ b/mz_enable_vae_encode_tiling.py @@ -0,0 +1,188 @@ +# thanks to MinusZoneAI: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/b98b98bd04621e4c85547866c12de2ec723ae98a/mz_enable_vae_encode_tiling.py +from typing import Optional +import torch +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_outputs import AutoencoderKLOutput + + +@apply_forward_hook +def encode( + self, x: torch.Tensor, return_dict: bool = True +): + """ + Encode a batch of images into latents. + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + Returns: + The latent representations of the encoded videos. 
If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + +def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + Args: + x (`torch.Tensor`): Input batch of videos. + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + overlap_height = int(self.tile_sample_min_height * + (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * + (1 - self.tile_overlap_factor_width)) + blend_extent_height = int( + self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int( + self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = 4 + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. 
+ num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + \ + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i: i + self.tile_sample_min_height, + j: j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v( + rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append( + tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + enc = torch.cat(result_rows, dim=3) + return enc + + +def _encode( + self, x: torch.Tensor, return_dict: bool = True +): + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_encode_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if num_frames == 1: + h = self.encoder(x) + if self.quant_conv is not None: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + else: + frame_batch_size = 4 + h = [] + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + \ + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = x[:, :, start_frame:end_frame] + z_intermediate = self.encoder(z_intermediate) + if self.quant_conv is not None: + z_intermediate = self.quant_conv(z_intermediate) + h.append(z_intermediate) + self._clear_fake_context_parallel_cache() + h = torch.cat(h, dim=2) + return h + + +def enable_encode_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, +) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. 
This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_encode_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / + (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int( + self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + +from types import MethodType + + +def enable_vae_encode_tiling(vae): + vae.encode = MethodType(encode, vae) + setattr(vae, "_encode", MethodType(_encode, vae)) + setattr(vae, "tiled_encode", MethodType(tiled_encode, vae)) + setattr(vae, "use_encode_tiling", True) + + setattr(vae, "enable_encode_tiling", MethodType(enable_encode_tiling, vae)) + vae.enable_encode_tiling() + return vae diff --git a/nodes.py b/nodes.py index 95811ff..7e392dc 100644 --- a/nodes.py +++ b/nodes.py @@ -3,7 +3,7 @@ import torch import folder_paths import comfy.model_management as mm from comfy.utils import ProgressBar, load_torch_file - +from einops import rearrange import importlib.metadata def check_diffusers_version(): @@ -1160,7 +1160,78 @@ class CogVideoXFunVid2VidSampler: # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight) return (pipeline, {"samples": latents}) - + + +class CogVideoControlImageEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "pipeline": ("COGVIDEOPIPE",), + "control_video": ("IMAGE", ), + "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}), + "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), + }, + } + + RETURN_TYPES = ("COGCONTROL_LATENTS",) + RETURN_NAMES = ("control_latents",) + FUNCTION = "encode" + CATEGORY = "CogVideoWrapper" + + def encode(self, pipeline, control_video, base_resolution, enable_tiling): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + + B, H, W, C = control_video.shape + + vae = pipeline["pipe"].vae + vae.enable_slicing() + + if enable_tiling: + from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling + enable_vae_encode_tiling(vae) + + if not pipeline["cpu_offloading"]: + vae.to(device) + + # Count most suitable height and width + aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + + control_video = np.array(control_video.cpu().numpy() * 255, np.uint8) + original_width, original_height = Image.fromarray(control_video[0]).size + + closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) + height, width = [int(x / 16) * 16 for x in closest_size] + + video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1 + input_video, 
input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width)) + + control_video = pipeline["pipe"].image_processor.preprocess(rearrange(input_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + + masked_image = control_video.to(device=device, dtype=vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * vae.config.scaling_factor + + vae.to(offload_device) + + control_latents = { + "latents": masked_image_latents, + "num_frames" : B, + "height" : height, + "width" : width, + } + + return (control_latents, ) + class CogVideoXFunControlSampler: @classmethod def INPUT_TYPES(s): @@ -1169,10 +1240,7 @@ class CogVideoXFunControlSampler: "pipeline": ("COGVIDEOPIPE",), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), - "video_length": ("INT", {"default": 49, "min": 5, "max": 49, "step": 4}), - "base_resolution": ( - [256,320,384,448,512,768,960,1024,], {"default": 512} - ), + "control_latents": ("COGCONTROL_LATENTS",), "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}), "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}), @@ -1194,7 +1262,6 @@ class CogVideoXFunControlSampler: "default": 'DDIM' } ), - "control_video": ("IMAGE",), "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), @@ -1206,8 +1273,8 @@ class CogVideoXFunControlSampler: FUNCTION = "process" CATEGORY = "CogVideoWrapper" - def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler, - control_video=None, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0): + def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, + control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0): device = mm.get_torch_device() offload_device = mm.unet_offload_device() pipe = pipeline["pipe"] @@ -1221,15 +1288,6 @@ class CogVideoXFunControlSampler: mm.soft_empty_cache() - # Count most suitable height and width - aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} - - control_video = np.array(control_video.cpu().numpy() * 255, np.uint8) - original_width, original_height = Image.fromarray(control_video[0]).size - - closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) - height, width = [int(x / 16) * 16 for x in closest_size] - # Load Sampler scheduler_config = pipeline["scheduler_config"] if scheduler in scheduler_mapping: @@ -1243,8 +1301,6 @@ class CogVideoXFunControlSampler: autocastcondition = not pipeline["onediff"] autocast_context = torch.autocast(mm.get_autocast_device(device)) 
if autocastcondition else nullcontext() with autocast_context: - video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 - input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width)) # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): # pipeline = merge_lora(pipeline, _lora_path, _lora_weight) @@ -1252,9 +1308,9 @@ class CogVideoXFunControlSampler: common_params = { "prompt_embeds": positive.to(dtype).to(device), "negative_prompt_embeds": negative.to(dtype).to(device), - "num_frames": video_length, - "height": height, - "width": width, + "num_frames": control_latents["num_frames"], + "height": control_latents["height"], + "width": control_latents["width"], "generator": generator, "guidance_scale": cfg, "num_inference_steps": steps, @@ -1263,7 +1319,7 @@ class CogVideoXFunControlSampler: latents = pipe( **common_params, - control_video=input_video, + control_video=control_latents["latents"], control_strength=control_strength, control_start_percent=control_start_percent, control_end_percent=control_end_percent @@ -1286,7 +1342,8 @@ NODE_CLASS_MAPPINGS = { "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine, "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel, "CogVideoPABConfig": CogVideoPABConfig, - "CogVideoTransformerEdit": CogVideoTransformerEdit + "CogVideoTransformerEdit": CogVideoTransformerEdit, + "CogVideoControlImageEncode": CogVideoControlImageEncode } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -1301,5 +1358,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine", "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model", "CogVideoPABConfig": "CogVideo PABConfig", - "CogVideoTransformerEdit": "CogVideo TransformerEdit" + "CogVideoTransformerEdit": "CogVideo TransformerEdit", + "CogVideoControlImageEncode": "CogVideo Control ImageEncode" }
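
Usage sketch (not part of the diff above): mz_enable_vae_encode_tiling.py monkey-patches encode-side tiling onto a diffusers CogVideoX VAE instance, mirroring the decode-side tiling diffusers already ships. The sketch below shows how the helper could be applied to a bare VAE outside of ComfyUI; the model id, dtype, resolution and frame count are illustrative assumptions, not values taken from this patch.

    # Sketch only: apply the patch's encode-tiling helper to a CogVideoX VAE.
    # Assumptions: "THUDM/CogVideoX-2b" weights, fp16, 49 frames at 720x1280.
    import torch
    from diffusers import AutoencoderKLCogVideoX

    from mz_enable_vae_encode_tiling import enable_vae_encode_tiling

    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")
    vae.enable_slicing()

    # Binds encode(), _encode(), tiled_encode() and enable_encode_tiling()
    # onto this instance and switches encode tiling on.
    enable_vae_encode_tiling(vae)

    # (batch, channels, frames, height, width) video scaled to [-1, 1].
    # Frames larger than tile_sample_min_height/width are now encoded
    # tile-by-tile and blended, keeping peak VRAM roughly constant.
    video = torch.rand(1, 3, 49, 720, 1280, dtype=vae.dtype, device="cuda") * 2 - 1

    with torch.no_grad():
        posterior = vae.encode(video)[0]
        latents = posterior.mode() * vae.config.scaling_factor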
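
Data-flow note (not part of the diff above): the control-video preprocessing and VAE encode that previously ran inside CogVideoXFunControlSampler now happen once in the new CogVideoControlImageEncode node, and the sampler only duplicates the pre-computed latents for classifier-free guidance. The two nodes communicate through the new COGCONTROL_LATENTS dict; the sketch below spells out its fields. The field names come from the patch, while the concrete shape is an assumption for illustration (CogVideoX's 16-channel VAE with 8x spatial and 4x temporal compression).

    import torch

    # Illustrative COGCONTROL_LATENTS payload, as produced by
    # CogVideoControlImageEncode and consumed by CogVideoXFunControlSampler.
    num_frames, height, width = 49, 480, 720
    control_latents = {
        # vae.encode(...).mode() * vae.config.scaling_factor, laid out as
        # (batch, channels, latent_frames, latent_height, latent_width)
        "latents": torch.zeros(1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
        "num_frames": num_frames,  # frame count of the input IMAGE batch
        "height": height,          # pixel height after aspect-ratio bucketing
        "width": width,            # pixel width after aspect-ratio bucketing
    }

    # The sampler forwards control_latents["latents"] as the pipeline's
    # control_video argument, which after this patch expects latents
    # rather than pixel frames.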