Ability to split the batch for the other decoder node (MochiDecodeSpatialTiling)

kijai 2024-10-26 00:42:48 +03:00
parent b932036af3
commit 3348a0fed7


@@ -509,6 +509,7 @@ class MochiDecodeSpatialTiling:
                 "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixels of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
+                "per_batch": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of samples per batch"}),
             },
         }
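
The new per_batch input controls how many latent frames the decode() loop (added further down in this diff) processes at a time. As a quick illustration of the chunk boundaries it produces, with hypothetical values:

    T, per_batch = 25, 6  # hypothetical frame count and batch size
    for i in range(0, T, per_batch):
        end_index = min(i + per_batch, T)   # final chunk may be shorter
        print(f"chunk [{i}:{end_index}]")   # [0:6], [6:12], [12:18], [18:24], [24:25]
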
@@ -518,7 +519,7 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"
 
     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size):
+               min_block_size, per_batch):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -528,16 +529,27 @@ class MochiDecodeSpatialTiling:
         B, C, T, H, W = samples.shape
 
         vae.to(device)
+        decoded_list = []
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling:
                 from .mochi_preview.vae.model import apply_tiled
                 logging.warning("Decoding with tiling...")
-                frames = apply_tiled(vae, samples, num_tiles_w=num_tiles_w, num_tiles_h=num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+                for i in range(0, T, per_batch):
+                    end_index = min(i + per_batch, T)
+                    chunk = samples[:, :, i:end_index, :, :]
+                    frames = apply_tiled(vae, chunk, num_tiles_w=num_tiles_w, num_tiles_h=num_tiles_h, overlap=overlap, min_block_size=min_block_size)
+                    print(frames.shape)
+                    # Blend the last frame of the previous chunk with the first frame of this one
+                    if len(decoded_list) > 0:
+                        previous_frames = decoded_list[-1]
+                        blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
+                        decoded_list[-1][:, :, -1:, :, :] = blended_frames
+                    decoded_list.append(frames)
+                frames = torch.cat(decoded_list, dim=2)
             else:
                 logging.info("Decoding without tiling...")
                 frames = vae(samples)
         vae.to(offload_device)
         frames = frames.float()
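
For reference, a minimal self-contained sketch of the chunk-and-blend scheme this commit adds. decode_stub is a hypothetical stand-in for the repo's vae / apply_tiled call; tensors follow the (B, C, T, H, W) layout used by the node:

    import torch

    def decode_stub(latents):
        # Hypothetical placeholder decoder: identity, keeps the temporal axis intact.
        return latents.clone()

    def decode_in_chunks(samples, per_batch):
        B, C, T, H, W = samples.shape
        decoded_list = []
        for i in range(0, T, per_batch):
            end_index = min(i + per_batch, T)
            chunk = samples[:, :, i:end_index, :, :]
            frames = decode_stub(chunk)
            # Average the last frame of the previous chunk with the first
            # frame of this one to soften the seam between chunks.
            if decoded_list:
                previous_frames = decoded_list[-1]
                blended = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
                decoded_list[-1][:, :, -1:, :, :] = blended
            decoded_list.append(frames)
        return torch.cat(decoded_list, dim=2)

    samples = torch.randn(1, 4, 25, 8, 8)
    print(decode_in_chunks(samples, per_batch=6).shape)  # torch.Size([1, 4, 25, 8, 8])

Note that only the trailing frame of each earlier chunk is rewritten; the first frame of the following chunk is kept as decoded, so the total frame count is unchanged.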