separate control image encoder and add tiled encode support for it

tiled encoding code thanks to MinusZoneML:
https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/main/mz_enable_vae_encode_tiling.py
This commit is contained in:
kijai 2024-10-02 18:08:11 +03:00
parent 3de0113927
commit 627af9341c
3 changed files with 294 additions and 45 deletions

View File

@ -590,26 +590,29 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
if comfyui_progressbar:
pbar.update(1)
if control_video is not None:
video_length = control_video.shape[2]
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
control_video = control_video.to(dtype=torch.float32)
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
else:
control_video = None
control_video_latents = self.prepare_control_latents(
None,
control_video,
batch_size,
height,
width,
self.vae.dtype,
device,
generator,
do_classifier_free_guidance
)[1]
# if control_video is not None:
# video_length = control_video.shape[2]
# control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
# control_video = control_video.to(dtype=torch.float32)
# control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
# else:
# control_video = None
# control_video_latents = self.prepare_control_latents(
# None,
# control_video,
# batch_size,
# height,
# width,
# self.vae.dtype,
# device,
# generator,
# do_classifier_free_guidance
# )[1]
control_video_latents_input = (
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video
)
control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")

View File

@ -0,0 +1,188 @@
# thanks to MinusZoneAI: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/b98b98bd04621e4c85547866c12de2ec723ae98a/mz_enable_vae_encode_tiling.py
from typing import Optional
import torch
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
):
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
# For a rough memory estimate, take a look at the `tiled_decode` method.
batch_size, num_channels, num_frames, height, width = x.shape
overlap_height = int(self.tile_sample_min_height *
(1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width *
(1 - self.tile_overlap_factor_width))
blend_extent_height = int(
self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(
self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
frame_batch_size = 4
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + \
(0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = x[
:,
:,
start_frame:end_frame,
i: i + self.tile_sample_min_height,
j: j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(
rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(
tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
return enc
def _encode(
self, x: torch.Tensor, return_dict: bool = True
):
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_encode_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
if num_frames == 1:
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(h)
else:
frame_batch_size = 4
h = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + \
(0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = x[:, :, start_frame:end_frame]
z_intermediate = self.encoder(z_intermediate)
if self.quant_conv is not None:
z_intermediate = self.quant_conv(z_intermediate)
h.append(z_intermediate)
self._clear_fake_context_parallel_cache()
h = torch.cat(h, dim=2)
return h
def enable_encode_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_encode_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height /
(2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(
self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
from types import MethodType
def enable_vae_encode_tiling(vae):
vae.encode = MethodType(encode, vae)
setattr(vae, "_encode", MethodType(_encode, vae))
setattr(vae, "tiled_encode", MethodType(tiled_encode, vae))
setattr(vae, "use_encode_tiling", True)
setattr(vae, "enable_encode_tiling", MethodType(enable_encode_tiling, vae))
vae.enable_encode_tiling()
return vae

110
nodes.py
View File

@ -3,7 +3,7 @@ import torch
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from einops import rearrange
import importlib.metadata
def check_diffusers_version():
@ -1160,7 +1160,78 @@ class CogVideoXFunVid2VidSampler:
# for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
# pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
return (pipeline, {"samples": latents})
class CogVideoControlImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"pipeline": ("COGVIDEOPIPE",),
"control_video": ("IMAGE", ),
"base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
},
}
RETURN_TYPES = ("COGCONTROL_LATENTS",)
RETURN_NAMES = ("control_latents",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, control_video, base_resolution, enable_tiling):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
B, H, W, C = control_video.shape
vae = pipeline["pipe"].vae
vae.enable_slicing()
if enable_tiling:
from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
enable_vae_encode_tiling(vae)
if not pipeline["cpu_offloading"]:
vae.to(device)
# Count most suitable height and width
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
original_width, original_height = Image.fromarray(control_video[0]).size
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1
input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width))
control_video = pipeline["pipe"].image_processor.preprocess(rearrange(input_video, "b c f h w -> (b f) c h w"), height=height, width=width)
control_video = control_video.to(dtype=torch.float32)
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
masked_image = control_video.to(device=device, dtype=vae.dtype)
bs = 1
new_mask_pixel_values = []
for i in range(0, masked_image.shape[0], bs):
mask_pixel_values_bs = masked_image[i : i + bs]
mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
mask_pixel_values_bs = mask_pixel_values_bs.mode()
new_mask_pixel_values.append(mask_pixel_values_bs)
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
masked_image_latents = masked_image_latents * vae.config.scaling_factor
vae.to(offload_device)
control_latents = {
"latents": masked_image_latents,
"num_frames" : B,
"height" : height,
"width" : width,
}
return (control_latents, )
class CogVideoXFunControlSampler:
@classmethod
def INPUT_TYPES(s):
@ -1169,10 +1240,7 @@ class CogVideoXFunControlSampler:
"pipeline": ("COGVIDEOPIPE",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"video_length": ("INT", {"default": 49, "min": 5, "max": 49, "step": 4}),
"base_resolution": (
[256,320,384,448,512,768,960,1024,], {"default": 512}
),
"control_latents": ("COGCONTROL_LATENTS",),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
@ -1194,7 +1262,6 @@ class CogVideoXFunControlSampler:
"default": 'DDIM'
}
),
"control_video": ("IMAGE",),
"control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
@ -1206,8 +1273,8 @@ class CogVideoXFunControlSampler:
FUNCTION = "process"
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
control_video=None, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
@ -1221,15 +1288,6 @@ class CogVideoXFunControlSampler:
mm.soft_empty_cache()
# Count most suitable height and width
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
original_width, original_height = Image.fromarray(control_video[0]).size
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
# Load Sampler
scheduler_config = pipeline["scheduler_config"]
if scheduler in scheduler_mapping:
@ -1243,8 +1301,6 @@ class CogVideoXFunControlSampler:
autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
with autocast_context:
video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width))
# for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
# pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
@ -1252,9 +1308,9 @@ class CogVideoXFunControlSampler:
common_params = {
"prompt_embeds": positive.to(dtype).to(device),
"negative_prompt_embeds": negative.to(dtype).to(device),
"num_frames": video_length,
"height": height,
"width": width,
"num_frames": control_latents["num_frames"],
"height": control_latents["height"],
"width": control_latents["width"],
"generator": generator,
"guidance_scale": cfg,
"num_inference_steps": steps,
@ -1263,7 +1319,7 @@ class CogVideoXFunControlSampler:
latents = pipe(
**common_params,
control_video=input_video,
control_video=control_latents["latents"],
control_strength=control_strength,
control_start_percent=control_start_percent,
control_end_percent=control_end_percent
@ -1286,7 +1342,8 @@ NODE_CLASS_MAPPINGS = {
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
"CogVideoPABConfig": CogVideoPABConfig,
"CogVideoTransformerEdit": CogVideoTransformerEdit
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -1301,5 +1358,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
"CogVideoPABConfig": "CogVideo PABConfig",
"CogVideoTransformerEdit": "CogVideo TransformerEdit"
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode"
}