diff --git a/nodes.py b/nodes.py index 1546747..2d8eb48 100644 --- a/nodes.py +++ b/nodes.py @@ -177,14 +177,15 @@ class DownloadAndLoadMochiModel: nonlinearity="silu", output_nonlinearity="silu", causal=True, + dtype=dtype, ) vae_sd = load_torch_file(vae_path) if is_accelerate_available: for key in vae_sd: - set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key]) + set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key]) else: vae.load_state_dict(vae_sd, strict=True) - vae.eval().to(torch.bfloat16).to("cpu") + vae.eval().to(dtype).to("cpu") del vae_sd return (model, vae,)