From 99285ca1e7700613990cbd35119f45625bd8a889 Mon Sep 17 00:00:00 2001 From: Yoshimasa Niwa Date: Tue, 5 Nov 2024 13:14:17 +0900 Subject: [PATCH 1/2] Use float16 for autocast on mps MPS only supports float16 for autocast for now. --- mochi_preview/t2v_synth_mochi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 798a7aa..f20808d 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -240,7 +240,10 @@ class T2VSynthMochiModel: if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: autocast_dtype = torch.float16 else: - autocast_dtype = torch.bfloat16 + if self.device.type == "mps": + autocast_dtype = torch.float16 + else: + autocast_dtype = torch.bfloat16 def model_fn(*, z, sigma, cfg_scale): nonlocal sample, sample_null From 24a834be796f5aa16a20e2ce4fc3768acaa0b4cc Mon Sep 17 00:00:00 2001 From: Yoshimasa Niwa Date: Tue, 5 Nov 2024 13:46:45 +0900 Subject: [PATCH 2/2] Use dtype for vae. Somehow, the download node is not using dtype for vae. --- nodes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 1546747..2d8eb48 100644 --- a/nodes.py +++ b/nodes.py @@ -177,14 +177,15 @@ class DownloadAndLoadMochiModel: nonlinearity="silu", output_nonlinearity="silu", causal=True, + dtype=dtype, ) vae_sd = load_torch_file(vae_path) if is_accelerate_available: for key in vae_sd: - set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key]) + set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key]) else: vae.load_state_dict(vae_sd, strict=True) - vae.eval().to(torch.bfloat16).to("cpu") + vae.eval().to(dtype).to("cpu") del vae_sd return (model, vae,)