[Core] feat: Add --safetensors-load-strategy flag for faster safetensors loading from Lustre (#24469)

Signed-off-by: Shiqi Sheng <shengshiqi@google.com>
Signed-off-by: shengshiqi-google <160179165+shengshiqi-google@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: shengshiqi-google
Date: 2025-09-11 06:10:01 +00:00 (committed by GitHub)
Commit: 41329a0ff9 (parent: ee0bc5e1b4)
4 changed files with 31 additions and 6 deletions


@@ -51,6 +51,15 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
+    safetensors_load_strategy: Optional[str] = "lazy"
+    """Specifies the loading strategy for safetensors weights.
+    - "lazy" (default): Weights are memory-mapped from the file. This enables
+      on-demand loading and is highly efficient for models on local storage.
+    - "eager": The entire file is read into CPU memory upfront before loading.
+      This is recommended for models on network filesystems (e.g., Lustre, NFS)
+      as it avoids inefficient random reads, significantly speeding up model
+      initialization. However, it uses more CPU RAM.
+    """
     model_loader_extra_config: Union[dict, TensorizerConfig] = field(
         default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader


@@ -289,6 +289,8 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
+    safetensors_load_strategy: Optional[
+        str] = LoadConfig.safetensors_load_strategy
     load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
@@ -587,6 +589,8 @@ class EngineArgs:
         load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
+        load_group.add_argument("--safetensors-load-strategy",
+                                **load_kwargs["safetensors_load_strategy"])
         load_group.add_argument("--model-loader-extra-config",
                                 **load_kwargs["model_loader_extra_config"])
         load_group.add_argument("--ignore-patterns",
@@ -1023,6 +1027,7 @@ class EngineArgs:
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
+            safetensors_load_strategy=self.safetensors_load_strategy,
             device="cpu"
             if is_online_quantization(self.quantization) else None,
             model_loader_extra_config=self.model_loader_extra_config,
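The setting is also exposed on the command line as --safetensors-load-strategy. A hedged sketch of the programmatic path, assuming the return LoadConfig(...) above lives in EngineArgs.create_load_config (the model name is illustrative):

from vllm.engine.arg_utils import EngineArgs

# CLI equivalent (illustrative):
#   vllm serve <model> --safetensors-load-strategy eager
args = EngineArgs(
    model="facebook/opt-125m",  # illustrative model name
    safetensors_load_strategy="eager",
)
# Method name assumed from the hunk above, which sits inside it.
load_config = args.create_load_config()
assert load_config.safetensors_load_strategy == "eager"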


@@ -189,6 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
             weights_iterator = safetensors_weights_iterator(
                 hf_weights_files,
                 self.load_config.use_tqdm_on_load,
+                self.load_config.safetensors_load_strategy,
             )
         else:
             if extra_config.get("enable_multithread_load"):
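The loader simply threads the strategy through to the weights iterator; calling the helper directly would look like the following sketch (the shard paths are illustrative):

from vllm.model_executor.model_loader.weight_utils import (
    safetensors_weights_iterator)

# Illustrative shard paths on a Lustre mount.
files = [
    "/lustre/ckpt/model-00001-of-00002.safetensors",
    "/lustre/ckpt/model-00002-of-00002.safetensors",
]

# With "eager", each shard is read sequentially into RAM before tensors
# are yielded; with "lazy" (the default), tensors are memory-mapped and
# paged in on demand.
for name, tensor in safetensors_weights_iterator(
        files,
        use_tqdm_on_load=True,
        safetensors_load_strategy="eager"):
    print(name, tuple(tensor.shape))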


@@ -19,7 +19,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
-from safetensors.torch import load_file, safe_open, save_file
+from safetensors.torch import load, load_file, safe_open, save_file
 from tqdm.auto import tqdm

 from vllm import envs
@@ -519,18 +519,28 @@ def np_cache_weights_iterator(
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
+    safetensors_load_strategy: Optional[str] = "lazy",
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
+    loading_desc = "Loading safetensors checkpoint shards"
+    if safetensors_load_strategy == "eager":
+        loading_desc += " (eager)"
+
     for st_file in tqdm(
             hf_weights_files,
-            desc="Loading safetensors checkpoint shards",
+            desc=loading_desc,
             disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
+        if safetensors_load_strategy == "eager":
+            with open(st_file, "rb") as f:
+                state_dict = load(f.read())
+            yield from state_dict.items()
+        else:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param


 def multi_thread_safetensors_weights_iterator(
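For reference, the two strategies can be contrasted outside vLLM with the safetensors library alone; a minimal standalone sketch, assuming a local file named model.safetensors:

from safetensors.torch import load, safe_open

path = "model.safetensors"  # illustrative path

# Lazy (default): safe_open memory-maps the file, so each get_tensor()
# call pages data in on demand. Cheap on local disks, but random-read
# heavy on network filesystems such as Lustre or NFS.
with safe_open(path, framework="pt") as f:
    lazy = {name: f.get_tensor(name) for name in f.keys()}

# Eager: one sequential read() pulls the whole shard into CPU RAM, then
# load() parses tensors from the in-memory buffer.
with open(path, "rb") as f:
    eager = load(f.read())

assert lazy.keys() == eager.keys()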