mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 21:04:23 +08:00)
separate control image encoder and add tiled encode support for it
tiled encoding code thanks to MinusZoneML: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/main/mz_enable_vae_encode_tiling.py
This commit is contained in:
parent 3de0113927
commit 627af9341c
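For orientation before the diff: the commit moves control-video VAE encoding out of the sampler and into a new CogVideoControlImageEncode node, which can optionally patch tiled encoding onto the VAE. Below is a minimal sketch of that usage, mirroring the code added to nodes.py further down; the pipeline and control_pixels objects are assumed to already exist, and the import path may need adjusting to where the helper file lives:

    # Sketch only; mirrors the new CogVideoControlImageEncode node in this commit.
    from mz_enable_vae_encode_tiling import enable_vae_encode_tiling

    vae = pipeline["pipe"].vae           # CogVideoX VAE taken from the loaded pipeline
    vae.enable_slicing()                 # encode samples one at a time
    enable_vae_encode_tiling(vae)        # monkey-patch encode/_encode/tiled_encode onto this VAE instance

    # control_pixels: (B, C, F, H, W) tensor already preprocessed to the VAE's dtype/device (assumed)
    latents = vae.encode(control_pixels)[0].mode()
    latents = latents * vae.config.scaling_factor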
@@ -590,26 +590,29 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
        if comfyui_progressbar:
            pbar.update(1)

        if control_video is not None:
            video_length = control_video.shape[2]
            control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            control_video = control_video.to(dtype=torch.float32)
            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
        else:
            control_video = None
        control_video_latents = self.prepare_control_latents(
            None,
            control_video,
            batch_size,
            height,
            width,
            self.vae.dtype,
            device,
            generator,
            do_classifier_free_guidance
        )[1]
        # if control_video is not None:
        #     video_length = control_video.shape[2]
        #     control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
        #     control_video = control_video.to(dtype=torch.float32)
        #     control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
        # else:
        #     control_video = None

        # control_video_latents = self.prepare_control_latents(
        #     None,
        #     control_video,
        #     batch_size,
        #     height,
        #     width,
        #     self.vae.dtype,
        #     device,
        #     generator,
        #     do_classifier_free_guidance
        # )[1]

        control_video_latents_input = (
            torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
            torch.cat([control_video] * 2) if do_classifier_free_guidance else control_video
        )
        control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
mz_enable_vae_encode_tiling.py (new file, +188 lines)
@@ -0,0 +1,188 @@
# thanks to MinusZoneAI: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/b98b98bd04621e4c85547866c12de2ec723ae98a/mz_enable_vae_encode_tiling.py
from typing import Optional
import torch
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput


@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
):
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)
    posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)

def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of images using a tiled encoder.

    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
    different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    output, but they should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of videos.

    Returns:
        `torch.Tensor`:
            The latent representation of the encoded videos.
    """
    # For a rough memory estimate, take a look at the `tiled_decode` method.
    batch_size, num_channels, num_frames, height, width = x.shape
    overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_latent_min_height - blend_extent_height
    row_limit_width = self.tile_latent_min_width - blend_extent_width
    frame_batch_size = 4
    # Split x into overlapping tiles and encode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
            num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
            time = []
            for k in range(num_batches):
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = x[
                    :,
                    :,
                    start_frame:end_frame,
                    i: i + self.tile_sample_min_height,
                    j: j + self.tile_sample_min_width,
                ]
                tile = self.encoder(tile)
                if self.quant_conv is not None:
                    tile = self.quant_conv(tile)
                time.append(tile)
            self._clear_fake_context_parallel_cache()
            row.append(torch.cat(time, dim=2))
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))
    enc = torch.cat(result_rows, dim=3)
    return enc

def _encode(
    self, x: torch.Tensor, return_dict: bool = True
):
    batch_size, num_channels, num_frames, height, width = x.shape

    if self.use_encode_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
        return self.tiled_encode(x)

    if num_frames == 1:
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(h)
    else:
        frame_batch_size = 4
        h = []
        for i in range(num_frames // frame_batch_size):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            z_intermediate = x[:, :, start_frame:end_frame]
            z_intermediate = self.encoder(z_intermediate)
            if self.quant_conv is not None:
                z_intermediate = self.quant_conv(z_intermediate)
            h.append(z_intermediate)
        self._clear_fake_context_parallel_cache()
        h = torch.cat(h, dim=2)
    return h

def enable_encode_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_overlap_factor_height: Optional[float] = None,
    tile_overlap_factor_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE encoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_overlap_factor_height (`float`, *optional*):
            The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
            no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
        tile_overlap_factor_width (`float`, *optional*):
            The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
            are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
    """
    self.use_encode_tiling = True
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    self.tile_latent_min_height = int(
        self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
    )
    self.tile_latent_min_width = int(
        self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))
    )
    self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
    self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width

from types import MethodType


def enable_vae_encode_tiling(vae):
    vae.encode = MethodType(encode, vae)
    setattr(vae, "_encode", MethodType(_encode, vae))
    setattr(vae, "tiled_encode", MethodType(tiled_encode, vae))
    setattr(vae, "use_encode_tiling", True)

    setattr(vae, "enable_encode_tiling", MethodType(enable_encode_tiling, vae))
    vae.enable_encode_tiling()
    return vae
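To make the tile geometry above concrete, here is a small standalone sketch that reproduces the overlap and blend arithmetic from tiled_encode for illustrative values; the tile sizes, overlap factors, and 8x spatial compression are assumptions chosen for the example, not values read from any particular model config:

    # Illustration of the tile geometry used by tiled_encode; all numbers are assumptions for this example.
    tile_sample_min_height, tile_sample_min_width = 480, 720              # pixel-space tile size (assumed)
    tile_overlap_factor_height, tile_overlap_factor_width = 1 / 6, 1 / 5  # overlap fractions (assumed)
    spatial_compression = 8                                               # 2 ** (len(block_out_channels) - 1) for a 4-level VAE (assumed)

    tile_latent_min_height = tile_sample_min_height // spatial_compression            # 60 latent rows per tile
    tile_latent_min_width = tile_sample_min_width // spatial_compression              # 90 latent columns per tile
    overlap_height = int(tile_sample_min_height * (1 - tile_overlap_factor_height))   # 400 px stride between vertical tiles
    overlap_width = int(tile_sample_min_width * (1 - tile_overlap_factor_width))      # 576 px stride between horizontal tiles
    blend_extent_height = int(tile_latent_min_height * tile_overlap_factor_height)    # 10 latent rows blended with the tile above
    blend_extent_width = int(tile_latent_min_width * tile_overlap_factor_width)       # 18 latent columns blended with the tile to the left
    row_limit_height = tile_latent_min_height - blend_extent_height                   # 50 latent rows each tile contributes
    row_limit_width = tile_latent_min_width - blend_extent_width                      # 72 latent columns each tile contributes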
nodes.py (110 lines changed)
@@ -3,7 +3,7 @@ import torch
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file

from einops import rearrange
import importlib.metadata

def check_diffusers_version():
@@ -1160,7 +1160,78 @@ class CogVideoXFunVid2VidSampler:
        # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
        # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
        return (pipeline, {"samples": latents})


class CogVideoControlImageEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipeline": ("COGVIDEOPIPE",),
            "control_video": ("IMAGE", ),
            "base_resolution": ("INT", {"min": 256, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
            "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
            },
        }

    RETURN_TYPES = ("COGCONTROL_LATENTS",)
    RETURN_NAMES = ("control_latents",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, pipeline, control_video, base_resolution, enable_tiling):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        B, H, W, C = control_video.shape

        vae = pipeline["pipe"].vae
        vae.enable_slicing()

        if enable_tiling:
            from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
            enable_vae_encode_tiling(vae)

        if not pipeline["cpu_offloading"]:
            vae.to(device)

        # Count most suitable height and width
        aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}

        control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
        original_width, original_height = Image.fromarray(control_video[0]).size

        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]

        video_length = int((B - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if B != 1 else 1
        input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width))

        control_video = pipeline["pipe"].image_processor.preprocess(rearrange(input_video, "b c f h w -> (b f) c h w"), height=height, width=width)
        control_video = control_video.to(dtype=torch.float32)
        control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)

        masked_image = control_video.to(device=device, dtype=vae.dtype)
        bs = 1
        new_mask_pixel_values = []
        for i in range(0, masked_image.shape[0], bs):
            mask_pixel_values_bs = masked_image[i : i + bs]
            mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
            mask_pixel_values_bs = mask_pixel_values_bs.mode()
            new_mask_pixel_values.append(mask_pixel_values_bs)
        masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
        masked_image_latents = masked_image_latents * vae.config.scaling_factor

        vae.to(offload_device)

        control_latents = {
            "latents": masked_image_latents,
            "num_frames" : B,
            "height" : height,
            "width" : width,
        }

        return (control_latents, )

class CogVideoXFunControlSampler:
    @classmethod
    def INPUT_TYPES(s):
@@ -1169,10 +1240,7 @@ class CogVideoXFunControlSampler:
            "pipeline": ("COGVIDEOPIPE",),
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "video_length": ("INT", {"default": 49, "min": 5, "max": 49, "step": 4}),
            "base_resolution": (
                [256,320,384,448,512,768,960,1024,], {"default": 512}
            ),
            "control_latents": ("COGCONTROL_LATENTS",),
            "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}),
            "cfg": ("FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}),
@@ -1194,7 +1262,6 @@ class CogVideoXFunControlSampler:
                    "default": 'DDIM'
                }
            ),
            "control_video": ("IMAGE",),
            "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "control_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
@@ -1206,8 +1273,8 @@ class CogVideoXFunControlSampler:
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
                control_video=None, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
                control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pipe = pipeline["pipe"]
@@ -1221,15 +1288,6 @@ class CogVideoXFunControlSampler:

        mm.soft_empty_cache()

        # Count most suitable height and width
        aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}

        control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
        original_width, original_height = Image.fromarray(control_video[0]).size

        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]

        # Load Sampler
        scheduler_config = pipeline["scheduler_config"]
        if scheduler in scheduler_mapping:
@@ -1243,8 +1301,6 @@ class CogVideoXFunControlSampler:
        autocastcondition = not pipeline["onediff"]
        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
        with autocast_context:
            video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
            input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width))

            # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
            # pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
@@ -1252,9 +1308,9 @@ class CogVideoXFunControlSampler:
            common_params = {
                "prompt_embeds": positive.to(dtype).to(device),
                "negative_prompt_embeds": negative.to(dtype).to(device),
                "num_frames": video_length,
                "height": height,
                "width": width,
                "num_frames": control_latents["num_frames"],
                "height": control_latents["height"],
                "width": control_latents["width"],
                "generator": generator,
                "guidance_scale": cfg,
                "num_inference_steps": steps,
@@ -1263,7 +1319,7 @@ class CogVideoXFunControlSampler:

            latents = pipe(
                **common_params,
                control_video=input_video,
                control_video=control_latents["latents"],
                control_strength=control_strength,
                control_start_percent=control_start_percent,
                control_end_percent=control_end_percent
@@ -1286,7 +1342,8 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
    "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
    "CogVideoPABConfig": CogVideoPABConfig,
    "CogVideoTransformerEdit": CogVideoTransformerEdit
    "CogVideoTransformerEdit": CogVideoTransformerEdit,
    "CogVideoControlImageEncode": CogVideoControlImageEncode
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1301,5 +1358,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
    "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
    "CogVideoPABConfig": "CogVideo PABConfig",
    "CogVideoTransformerEdit": "CogVideo TransformerEdit"
    "CogVideoTransformerEdit": "CogVideo TransformerEdit",
    "CogVideoControlImageEncode": "CogVideo Control ImageEncode"
}
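Taken together, the commit splits the old all-in-one control workflow in two: CogVideoControlImageEncode pre-encodes the control video into latents (optionally with tiled VAE encoding), and CogVideoXFunControlSampler now consumes those latents, reading num_frames, height and width from the same dict instead of computing them itself. A rough sketch of how the two node classes chain when called directly from Python, outside the ComfyUI graph; the pipeline and conditioning objects come from the loader and text-encode nodes, and the argument values are illustrative only:

    # Illustrative only: made-up settings, node classes called directly.
    control_latents, = CogVideoControlImageEncode().encode(
        pipeline, control_video, base_resolution=512, enable_tiling=True
    )
    result = CogVideoXFunControlSampler().process(
        pipeline, positive, negative,
        seed=42, steps=25, cfg=6.0, scheduler="DDIM",
        control_latents=control_latents,
        control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0,
    )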