[Core] feat: Add --safetensors-load-strategy flag for faster safetensors loading from Lustre (#24469)

Signed-off-by: Shiqi Sheng <shengshiqi@google.com>
Signed-off-by: shengshiqi-google <160179165+shengshiqi-google@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
shengshiqi-google 2025-09-11 06:10:01 +00:00 committed by GitHub
parent ee0bc5e1b4
commit 41329a0ff9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 6 deletions

View File

@ -51,6 +51,15 @@ class LoadConfig:
download_dir: Optional[str] = None download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default """Directory to download and load the weights, default to the default
cache directory of Hugging Face.""" cache directory of Hugging Face."""
safetensors_load_strategy: Optional[str] = "lazy"
"""Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage.
- "eager": The entire file is read into CPU memory upfront before loading.
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
"""
model_loader_extra_config: Union[dict, TensorizerConfig] = field( model_loader_extra_config: Union[dict, TensorizerConfig] = field(
default_factory=dict) default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader """Extra config for model loader. This will be passed to the model loader

View File

@ -289,6 +289,8 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: Optional[
str] = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype dtype: ModelDType = ModelConfig.dtype
@ -587,6 +589,8 @@ class EngineArgs:
load_group.add_argument("--load-format", **load_kwargs["load_format"]) load_group.add_argument("--load-format", **load_kwargs["load_format"])
load_group.add_argument("--download-dir", load_group.add_argument("--download-dir",
**load_kwargs["download_dir"]) **load_kwargs["download_dir"])
load_group.add_argument("--safetensors-load-strategy",
**load_kwargs["safetensors_load_strategy"])
load_group.add_argument("--model-loader-extra-config", load_group.add_argument("--model-loader-extra-config",
**load_kwargs["model_loader_extra_config"]) **load_kwargs["model_loader_extra_config"])
load_group.add_argument("--ignore-patterns", load_group.add_argument("--ignore-patterns",
@ -1023,6 +1027,7 @@ class EngineArgs:
return LoadConfig( return LoadConfig(
load_format=self.load_format, load_format=self.load_format,
download_dir=self.download_dir, download_dir=self.download_dir,
safetensors_load_strategy=self.safetensors_load_strategy,
device="cpu" device="cpu"
if is_online_quantization(self.quantization) else None, if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config, model_loader_extra_config=self.model_loader_extra_config,

View File

@ -189,6 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = safetensors_weights_iterator( weights_iterator = safetensors_weights_iterator(
hf_weights_files, hf_weights_files,
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
) )
else: else:
if extra_config.get("enable_multithread_load"): if extra_config.get("enable_multithread_load"):

View File

@ -19,7 +19,7 @@ import huggingface_hub.constants
import numpy as np import numpy as np
import torch import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm import envs from vllm import envs
@ -519,18 +519,28 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    safetensors_load_strategy: Optional[str] = "lazy",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from each safetensors shard, in order.

    When ``safetensors_load_strategy`` equals ``"eager"``, each shard file is
    read completely into memory and deserialized before its tensors are
    yielded. For any other value the shard is opened with ``safe_open`` and
    tensors are fetched one at a time.

    Args:
        hf_weights_files: Paths of the safetensors shard files to iterate.
        use_tqdm_on_load: Forwarded to ``enable_tqdm`` to decide whether the
            progress bar is shown.
        safetensors_load_strategy: ``"eager"`` selects the read-whole-file
            path; every other value selects the ``safe_open`` path.

    Yields:
        Tuples of tensor name and tensor, shard by shard.
    """
    read_eagerly = safetensors_load_strategy == "eager"
    # The progress-bar label advertises the eager mode so it is visible in logs.
    desc_suffix = " (eager)" if read_eagerly else ""
    progress = tqdm(
        hf_weights_files,
        desc="Loading safetensors checkpoint shards" + desc_suffix,
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for shard_path in progress:
        if not read_eagerly:
            with safe_open(shard_path, framework="pt") as shard:
                for tensor_name in shard.keys():  # noqa: SIM118
                    yield tensor_name, shard.get_tensor(tensor_name)
            continue

        # Eager path: one sequential read of the whole file, then deserialize
        # from the in-memory bytes; the file is closed before yielding.
        with open(shard_path, "rb") as raw_file:
            shard_tensors = load(raw_file.read())
        yield from shard_tensors.items()
def multi_thread_safetensors_weights_iterator( def multi_thread_safetensors_weights_iterator(