Add the ability to split the batch for the other decoder node
This commit is contained in:
parent
b932036af3
commit
3348a0fed7
20
nodes.py
20
nodes.py
@ -509,6 +509,7 @@ class MochiDecodeSpatialTiling:
|
||||
"num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
|
||||
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
|
||||
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
|
||||
"per_batch": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of samples per batch"}),
|
||||
},
|
||||
}
|
||||
|
||||
@ -518,7 +519,7 @@ class MochiDecodeSpatialTiling:
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
|
||||
min_block_size):
|
||||
min_block_size, per_batch):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
@ -528,16 +529,27 @@ class MochiDecodeSpatialTiling:
|
||||
B, C, T, H, W = samples.shape
|
||||
|
||||
vae.to(device)
|
||||
|
||||
decoded_list = []
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||
if enable_vae_tiling:
|
||||
from .mochi_preview.vae.model import apply_tiled
|
||||
logging.warning("Decoding with tiling...")
|
||||
frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||
for i in range(0, T, per_batch):
|
||||
end_index = min(i + per_batch, T)
|
||||
chunk = samples[:, :, i:end_index, :, :]
|
||||
frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||
print(frames.shape)
|
||||
# Average the last frame of the previous chunk with the first frame of this chunk and store the result as the previous chunk's last frame
|
||||
if len(decoded_list) > 0:
|
||||
previous_frames = decoded_list[-1]
|
||||
blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
|
||||
decoded_list[-1][:, :, -1:, :, :] = blended_frames
|
||||
|
||||
decoded_list.append(frames)
|
||||
else:
|
||||
logging.info("Decoding without tiling...")
|
||||
frames = vae(samples)
|
||||
|
||||
frames = torch.cat(decoded_list, dim=2)
|
||||
vae.to(offload_device)
|
||||
|
||||
frames = frames.float()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user