From 4ef7df00c9ebd020f68da1b65cbcdbe9b0fb4e67 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:50:40 +0200 Subject: [PATCH] Update nodes.py --- nodes.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 0a5b819..50ac380 100644 --- a/nodes.py +++ b/nodes.py @@ -119,6 +119,10 @@ class DownloadAndLoadMochiModel: mm.soft_empty_cache() dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] + if "fp8" in precision: + vae_dtype = torch.bfloat16 + else: + vae_dtype = dtype # Transformer model model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi') @@ -174,15 +178,15 @@ class DownloadAndLoadMochiModel: nonlinearity="silu", output_nonlinearity="silu", causal=True, - dtype=dtype, + dtype=vae_dtype, ) vae_sd = load_torch_file(vae_path) if is_accelerate_available: for key in vae_sd: - set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key]) + set_module_tensor_to_device(vae, key, dtype=vae_dtype, device=offload_device, value=vae_sd[key]) else: vae.load_state_dict(vae_sd, strict=True) - vae.eval().to(dtype).to("cpu") + vae.eval().to(vae_dtype).to("cpu") del vae_sd return (model, vae,)