Use dtype for vea.

Somehow, download node is not using dtype for vae.
This commit is contained in:
Yoshimasa Niwa 2024-11-05 13:46:45 +09:00
parent 99285ca1e7
commit 24a834be79

View File

@ -177,14 +177,15 @@ class DownloadAndLoadMochiModel:
nonlinearity="silu", nonlinearity="silu",
output_nonlinearity="silu", output_nonlinearity="silu",
causal=True, causal=True,
dtype=dtype,
) )
vae_sd = load_torch_file(vae_path) vae_sd = load_torch_file(vae_path)
if is_accelerate_available: if is_accelerate_available:
for key in vae_sd: for key in vae_sd:
set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key]) set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key])
else: else:
vae.load_state_dict(vae_sd, strict=True) vae.load_state_dict(vae_sd, strict=True)
vae.eval().to(torch.bfloat16).to("cpu") vae.eval().to(dtype).to("cpu")
del vae_sd del vae_sd
return (model, vae,) return (model, vae,)