[Core] feat: Add --safetensors-load-strategy flag for faster safetensors loading from Lustre (#24469)
Signed-off-by: Shiqi Sheng <shengshiqi@google.com>
Signed-off-by: shengshiqi-google <160179165+shengshiqi-google@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

Parent: ee0bc5e1b4
Commit: 41329a0ff9

@@ -51,6 +51,15 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
+    safetensors_load_strategy: Optional[str] = "lazy"
+    """Specifies the loading strategy for safetensors weights.
+    - "lazy" (default): Weights are memory-mapped from the file. This enables
+      on-demand loading and is highly efficient for models on local storage.
+    - "eager": The entire file is read into CPU memory upfront before loading.
+      This is recommended for models on network filesystems (e.g., Lustre, NFS)
+      as it avoids inefficient random reads, significantly speeding up model
+      initialization. However, it uses more CPU RAM.
+    """
     model_loader_extra_config: Union[dict, TensorizerConfig] = field(
         default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
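
A minimal sketch of setting the new field programmatically, mirroring how EngineArgs builds a LoadConfig later in this diff (the vllm.config import path is an assumption here; only the kwarg comes from the change above):

    from vllm.config import LoadConfig

    # "eager" trades extra CPU RAM for one sequential read per shard;
    # leaving the field at its default keeps the memory-mapped "lazy" path.
    load_config = LoadConfig(safetensors_load_strategy="eager")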

@@ -289,6 +289,8 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
+    safetensors_load_strategy: Optional[
+        str] = LoadConfig.safetensors_load_strategy
     load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype

@@ -587,6 +589,8 @@ class EngineArgs:
         load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
+        load_group.add_argument("--safetensors-load-strategy",
+                                **load_kwargs["safetensors_load_strategy"])
         load_group.add_argument("--model-loader-extra-config",
                                 **load_kwargs["model_loader_extra_config"])
         load_group.add_argument("--ignore-patterns",
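
With the argument registered, a server launch on a network-filesystem checkpoint might pass the flag like this (the model path is illustrative, not from the change):

    vllm serve /lustre/checkpoints/my-model --safetensors-load-strategy eager

Omitting the flag keeps the default "lazy" (memory-mapped) behavior.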

@@ -1023,6 +1027,7 @@ class EngineArgs:
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
+            safetensors_load_strategy=self.safetensors_load_strategy,
             device="cpu"
             if is_online_quantization(self.quantization) else None,
             model_loader_extra_config=self.model_loader_extra_config,

@@ -189,6 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
                 weights_iterator = safetensors_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
+                    self.load_config.safetensors_load_strategy,
                 )
             else:
                 if extra_config.get("enable_multithread_load"):

@@ -19,7 +19,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
-from safetensors.torch import load_file, safe_open, save_file
+from safetensors.torch import load, load_file, safe_open, save_file
 from tqdm.auto import tqdm

 from vllm import envs

@@ -519,18 +519,28 @@ def np_cache_weights_iterator(
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
+    safetensors_load_strategy: Optional[str] = "lazy",
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
+    loading_desc = "Loading safetensors checkpoint shards"
+    if safetensors_load_strategy == "eager":
+        loading_desc += " (eager)"
+
     for st_file in tqdm(
             hf_weights_files,
-            desc="Loading safetensors checkpoint shards",
+            desc=loading_desc,
             disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
-                yield name, param
+        if safetensors_load_strategy == "eager":
+            with open(st_file, "rb") as f:
+                state_dict = load(f.read())
+            yield from state_dict.items()
+        else:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param


 def multi_thread_safetensors_weights_iterator(
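
The behavioral difference between the two strategies can be reproduced outside vLLM with the same safetensors API the iterator uses; a minimal, self-contained sketch (the file name and tensor name are made up for illustration):

    import torch
    from safetensors.torch import load, safe_open, save_file

    # Write a small example shard (illustrative only).
    save_file({"layer.weight": torch.randn(4, 4)}, "shard.safetensors")

    # "lazy": safe_open memory-maps the file and materializes tensors on
    # demand, which becomes many small random reads on Lustre/NFS mounts.
    with safe_open("shard.safetensors", framework="pt") as f:
        lazy_tensors = {name: f.get_tensor(name) for name in f.keys()}

    # "eager": read the whole shard sequentially into CPU RAM, then parse it,
    # trading extra memory for one large sequential read.
    with open("shard.safetensors", "rb") as f:
        eager_tensors = load(f.read())

    assert set(lazy_tensors) == set(eager_tensors)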