From 3348a0fed7432a47bf6e2e452db34f7c171d0448 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 26 Oct 2024 00:42:48 +0300
Subject: [PATCH] ability to split the batch for the other decoder node

---
 nodes.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/nodes.py b/nodes.py
index c3854c7..7e2084b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -509,6 +509,7 @@ class MochiDecodeSpatialTiling:
             "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
             "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
             "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
+            "per_batch": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of latent frames to decode at a time"}),
             },
         }

@@ -518,7 +519,7 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"

     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-                min_block_size):
+                min_block_size, per_batch):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -528,16 +529,28 @@ class MochiDecodeSpatialTiling:
         B, C, T, H, W = samples.shape

         vae.to(device)
-
+        decoded_list = []
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling:
                 from .mochi_preview.vae.model import apply_tiled
                 logging.warning("Decoding with tiling...")
-                frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+                # Decode in chunks of per_batch latent frames to limit peak VRAM use
+                for i in range(0, T, per_batch):
+                    end_index = min(i + per_batch, T)
+                    chunk = samples[:, :, i:end_index, :, :]
+                    frames = apply_tiled(vae, chunk, num_tiles_w=num_tiles_w, num_tiles_h=num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+                    # Blend the last frame of the previous chunk with the first frame of this one to soften the chunk seam
+                    if len(decoded_list) > 0:
+                        previous_frames = decoded_list[-1]
+                        blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
+                        decoded_list[-1][:, :, -1:, :, :] = blended_frames
+
+                    decoded_list.append(frames)
+                frames = torch.cat(decoded_list, dim=2)
             else:
                 logging.info("Decoding without tiling...")
                 frames = vae(samples)
-
+
         vae.to(offload_device)
         frames = frames.float()
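
For reference, a minimal standalone sketch of the chunk-and-blend pattern the patch implements, assuming only PyTorch. The decode_in_chunks helper and the identity "decoder" below are hypothetical stand-ins for apply_tiled and the Mochi VAE, not part of the repo:

    import torch

    def decode_in_chunks(decode_fn, samples, per_batch):
        # samples: (B, C, T, H, W) latents; decode_fn maps a latent chunk
        # to frames with the same temporal extent (a simplifying assumption).
        T = samples.shape[2]
        decoded_list = []
        for i in range(0, T, per_batch):
            frames = decode_fn(samples[:, :, i:min(i + per_batch, T), :, :])
            if decoded_list:
                # Average the boundary frames of adjacent chunks to soften the seam.
                blended = (decoded_list[-1][:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
                decoded_list[-1][:, :, -1:, :, :] = blended
            decoded_list.append(frames)
        return torch.cat(decoded_list, dim=2)

    # Toy usage: an identity decoder on random latents preserves the shape.
    latents = torch.randn(1, 4, 13, 8, 8)
    out = decode_in_chunks(lambda x: x.clone(), latents, per_batch=6)
    assert out.shape == latents.shape

Note that the averaging only rewrites the previous chunk's final frame; both boundary frames are kept, so the output has the same number of frames as the input.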