diff --git a/nodes.py b/nodes.py index fe9f3ad..ca971a8 100644 --- a/nodes.py +++ b/nodes.py @@ -578,10 +578,9 @@ class MochiDecodeSpatialTiling: B, C, T, H, W = samples.shape - vae.to(device) decoded_list = [] - with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): + with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype): if enable_vae_tiling: from .mochi_preview.vae.model import apply_tiled @@ -595,7 +594,7 @@ class MochiDecodeSpatialTiling: frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size) logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples") pbar.update(1) - # Blend the first and last frames of each pair + # Blend the first and last frames of each pair if len(decoded_list) > 0: previous_frames = decoded_list[-1] blended_frames = (previous_frames[:, :, -1:, :, :] + frames[:, :, :1, :, :]) / 2