diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index cf0789c..69ae024 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -308,13 +308,10 @@ class T2VSynthMochiModel: comfy_pbar = ProgressBar(sample_steps) - if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: + if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps": autocast_dtype = torch.float16 else: - if self.device.type == "mps": - autocast_dtype = torch.float16 - else: - autocast_dtype = torch.bfloat16 + autocast_dtype = torch.bfloat16 self.dit.to(self.device)