diff --git a/vllm/config/load.py b/vllm/config/load.py
index e4999e36b49bf..68253359fc567 100644
--- a/vllm/config/load.py
+++ b/vllm/config/load.py
@@ -51,6 +51,15 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
+    safetensors_load_strategy: Optional[str] = "lazy"
+    """Specifies the loading strategy for safetensors weights.
+    - "lazy" (default): Weights are memory-mapped from the file. This enables
+      on-demand loading and is highly efficient for models on local storage.
+    - "eager": The entire file is read into CPU memory upfront before loading.
+      This is recommended for models on network filesystems (e.g., Lustre, NFS)
+      as it avoids inefficient random reads, significantly speeding up model
+      initialization. However, it uses more CPU RAM.
+    """
     model_loader_extra_config: Union[dict, TensorizerConfig] = field(
         default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 87d90f5147cd8..d9a29511eb529 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -289,6 +289,8 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
+    safetensors_load_strategy: Optional[
+        str] = LoadConfig.safetensors_load_strategy
     load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
@@ -587,6 +589,8 @@ class EngineArgs:
         load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
+        load_group.add_argument("--safetensors-load-strategy",
+                                **load_kwargs["safetensors_load_strategy"])
         load_group.add_argument("--model-loader-extra-config",
                                 **load_kwargs["model_loader_extra_config"])
         load_group.add_argument("--ignore-patterns",
@@ -1023,6 +1027,7 @@ class EngineArgs:
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
+            safetensors_load_strategy=self.safetensors_load_strategy,
             device="cpu" if is_online_quantization(self.quantization)
             else None,
             model_loader_extra_config=self.model_loader_extra_config,
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index f883e1e739102..d1bdec21fd974 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -189,6 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
                 weights_iterator = safetensors_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
+                    self.load_config.safetensors_load_strategy,
                 )
             else:
                 if extra_config.get("enable_multithread_load"):
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0de8dbbca9c7f..c6ca9cd48d009 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -19,7 +19,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
-from safetensors.torch import load_file, safe_open, save_file
+from safetensors.torch import load, load_file, safe_open, save_file
 from tqdm.auto import tqdm
 
 from vllm import envs
@@ -519,18 +519,28 @@ def np_cache_weights_iterator(
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
+    safetensors_load_strategy: Optional[str] = "lazy",
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
+    loading_desc = "Loading safetensors checkpoint shards"
+    if safetensors_load_strategy == "eager":
+        loading_desc += " (eager)"
+
     for st_file in tqdm(
             hf_weights_files,
-            desc="Loading safetensors checkpoint shards",
+            desc=loading_desc,
             disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
-                yield name, param
+        if safetensors_load_strategy == "eager":
+            with open(st_file, "rb") as f:
+                state_dict = load(f.read())
+            yield from state_dict.items()
+        else:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
 
 
 def multi_thread_safetensors_weights_iterator(