Update t2v_synth_mochi.py

This commit is contained in:
kijai 2024-11-06 17:27:13 +02:00
parent 76956cda50
commit a46643ae5a

View File

@ -308,13 +308,10 @@ class T2VSynthMochiModel:
comfy_pbar = ProgressBar(sample_steps)
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps":
autocast_dtype = torch.float16
else:
if self.device.type == "mps":
autocast_dtype = torch.float16
else:
autocast_dtype = torch.bfloat16
autocast_dtype = torch.bfloat16
self.dit.to(self.device)