Update nodes.py
This commit is contained in:
parent
a46643ae5a
commit
4ef7df00c9
10
nodes.py
10
nodes.py
@ -119,6 +119,10 @@ class DownloadAndLoadMochiModel:
|
|||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||||
|
if "fp8" in precision:
|
||||||
|
vae_dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
vae_dtype = dtype
|
||||||
|
|
||||||
# Transformer model
|
# Transformer model
|
||||||
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
|
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
|
||||||
@ -174,15 +178,15 @@ class DownloadAndLoadMochiModel:
|
|||||||
nonlinearity="silu",
|
nonlinearity="silu",
|
||||||
output_nonlinearity="silu",
|
output_nonlinearity="silu",
|
||||||
causal=True,
|
causal=True,
|
||||||
dtype=dtype,
|
dtype=vae_dtype,
|
||||||
)
|
)
|
||||||
vae_sd = load_torch_file(vae_path)
|
vae_sd = load_torch_file(vae_path)
|
||||||
if is_accelerate_available:
|
if is_accelerate_available:
|
||||||
for key in vae_sd:
|
for key in vae_sd:
|
||||||
set_module_tensor_to_device(vae, key, dtype=dtype, device=offload_device, value=vae_sd[key])
|
set_module_tensor_to_device(vae, key, dtype=vae_dtype, device=offload_device, value=vae_sd[key])
|
||||||
else:
|
else:
|
||||||
vae.load_state_dict(vae_sd, strict=True)
|
vae.load_state_dict(vae_sd, strict=True)
|
||||||
vae.eval().to(dtype).to("cpu")
|
vae.eval().to(vae_dtype).to("cpu")
|
||||||
del vae_sd
|
del vae_sd
|
||||||
|
|
||||||
return (model, vae,)
|
return (model, vae,)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user