Merge pull request #67 from niw/use_float16_on_mps

Use float16 for autocast on mps

commit 76956cda50
@@ -311,7 +311,10 @@ class T2VSynthMochiModel:
         if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
             autocast_dtype = torch.float16
         else:
-            autocast_dtype = torch.bfloat16
+            if self.device.type == "mps":
+                autocast_dtype = torch.float16
+            else:
+                autocast_dtype = torch.bfloat16

         self.dit.to(self.device)
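For context, a minimal standalone sketch of the dtype selection this hunk implements. This is not the repo's code: pick_autocast_dtype is a hypothetical helper, and the example assumes a PyTorch build whose torch.autocast accepts device_type="mps".

import torch

# Hypothetical helper mirroring the hunk above: the MPS backend has had
# incomplete bfloat16 support, so float16 is the safer half-precision
# choice on Apple GPUs.
def pick_autocast_dtype(device: torch.device) -> torch.dtype:
    if device.type == "mps":
        return torch.float16
    return torch.bfloat16

if torch.backends.mps.is_available():
    device = torch.device("mps")
    # Assumes torch.autocast supports the "mps" device type in this build.
    with torch.autocast(device_type="mps", dtype=pick_autocast_dtype(device)):
        x = torch.randn(4, 4, device=device)
        y = x @ x  # matmul runs in float16 under autocast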
||||
nodes.py (5 changed lines)
@@ -174,14 +174,15 @@ class DownloadAndLoadMochiModel:
             nonlinearity="silu",
             output_nonlinearity="silu",
             causal=True,
+            dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
             for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-            vae.eval().to(torch.bfloat16).to("cpu")
+            vae.eval().to(dtype).to("cpu")
         del vae_sd

         return (model, vae,)
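Similarly, a hedged sketch of the nodes.py pattern: the caller's dtype now flows into VAE weight loading instead of a hard-coded torch.bfloat16. load_vae_weights is an illustrative stand-in, not a function in the repo; set_module_tensor_to_device is the real accelerate utility the diff uses.

import torch

try:
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except ImportError:
    is_accelerate_available = False

# Illustrative stand-in for the loading path in nodes.py; `vae`, `vae_sd`,
# and `offload_device` correspond to the repo's VAE module, its state dict,
# and the device weights are parked on when idle.
def load_vae_weights(vae, vae_sd, dtype, offload_device):
    if is_accelerate_available:
        # Materialize each tensor on the offload device in the target dtype,
        # avoiding an intermediate full-precision copy of the weights.
        for key in vae_sd:
            set_module_tensor_to_device(vae, key, dtype=dtype,
                                        device=offload_device, value=vae_sd[key])
    else:
        vae.load_state_dict(vae_sd, strict=True)
        vae.eval().to(dtype).to("cpu")
    return vae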
||||