Update nodes.py

kijai 2024-10-30 22:08:22 +02:00
parent d971a19410
commit ebd0f62d53

@@ -578,10 +578,9 @@ class MochiDecodeSpatialTiling:
         B, C, T, H, W = samples.shape
         vae.to(device)
         decoded_list = []
-        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+        with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
             if enable_vae_tiling:
                 from .mochi_preview.vae.model import apply_tiled
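
The substantive change in the hunk above replaces a hardcoded torch.bfloat16 with vae.dtype, so the autocast region follows whatever precision the VAE was actually loaded in. Below is a minimal sketch of that pattern, outside the diff; the Conv3d stand-in and the vae_dtype lookup are illustrative assumptions, not this repo's real VAE or loader:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the Mochi VAE decoder: any module works to
# show the pattern of following the model's own dtype.
vae = torch.nn.Conv3d(12, 3, kernel_size=1).to(device=device, dtype=torch.bfloat16)
vae_dtype = next(vae.parameters()).dtype  # what the diff reads as vae.dtype

latents = torch.randn(1, 12, 7, 8, 8, device=device)

# Autocast to the module's own precision instead of a hardcoded bfloat16,
# so a VAE loaded in fp16 or fp32 decodes without a silent dtype mismatch.
with torch.autocast(device_type=device, dtype=vae_dtype):
    frames = vae(latents)

print(frames.dtype)  # torch.bfloat16 here; tracks vae_dtype in general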
@@ -595,7 +594,7 @@ class MochiDecodeSpatialTiling:
             frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
             logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples")
             pbar.update(1)
             # Blend the first and last frames of each pair
             if len(decoded_list) > 0:
                 previous_frames = decoded_list[-1]
                 blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
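
For context, the unchanged lines in the second hunk average the seam between consecutively decoded chunks: the last frame of the previous chunk and the first frame of the current one are mean-blended to hide the chunk boundary. The diff cuts off before showing how blended_frames is spliced back in, so the concatenation step in the sketch below is an assumption about one plausible way to finish the stitch, not the repo's exact code:

import torch

# Frames are (B, C, T, H, W); each chunk stands in for one decoded
# batch of frames from the chunked VAE decode loop.
chunks = [torch.randn(1, 3, 7, 64, 64) for _ in range(3)]

decoded_list = []
for frames in chunks:
    if len(decoded_list) > 0:
        previous_frames = decoded_list[-1]
        # Average the seam: last frame of the previous chunk with the
        # first frame of the current one, then drop both originals.
        blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2
        decoded_list[-1] = previous_frames[:, :, :-1, :, :]  # trim old last frame
        frames = torch.cat([blended_frames, frames[:, :, 1:, :, :]], dim=2)  # splice seam in
    decoded_list.append(frames)

video = torch.cat(decoded_list, dim=2)
print(video.shape)  # torch.Size([1, 3, 19, 64, 64]): 3*7 frames minus 2 merged seams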