diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 798a7aa..f20808d 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -240,7 +240,10 @@ class T2VSynthMochiModel: if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: autocast_dtype = torch.float16 else: - autocast_dtype = torch.bfloat16 + if self.device.type == "mps": + autocast_dtype = torch.float16 + else: + autocast_dtype = torch.bfloat16 def model_fn(*, z, sigma, cfg_scale): nonlocal sample, sample_null