diff --git a/nodes.py b/nodes.py index 3c1e947..46a9c8e 100644 --- a/nodes.py +++ b/nodes.py @@ -634,7 +634,7 @@ class MochiDecode: return torch.cat(result_rows, dim=3) vae.to(device) - with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): + with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype): if enable_vae_tiling and frame_batch_size > T: logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}") frame_batch_size = T