From a46643ae5ab02130e6267b3e0b7b5868274c5471 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:27:13 +0200 Subject: [PATCH] Update t2v_synth_mochi.py --- mochi_preview/t2v_synth_mochi.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index cf0789c..69ae024 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -308,13 +308,10 @@ class T2VSynthMochiModel: comfy_pbar = ProgressBar(sample_steps) - if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: + if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps": autocast_dtype = torch.float16 else: - if self.device.type == "mps": - autocast_dtype = torch.float16 - else: - autocast_dtype = torch.bfloat16 + autocast_dtype = torch.bfloat16 self.dit.to(self.device)