[Model loader]: support multi-thread model weight loading (#23928)

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Yang Kaiyong 2025-09-09 02:49:39 +08:00 committed by GitHub
parent 7be141b2c5
commit 43d9ad03ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 12 deletions

View File

@ -18,8 +18,9 @@ from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, maybe_download_from_modelscope, filter_files_not_needed_for_inference, maybe_download_from_modelscope,
np_cache_weights_iterator, pt_weights_iterator, multi_thread_pt_weights_iterator,
safetensors_weights_iterator) multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -28,6 +29,9 @@ logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader): class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk.""" """Model loader that can load different file types from disk."""
# default number of threads to use when multithread weight loading is enabled
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass @dataclasses.dataclass
class Source: class Source:
"""A source for weights.""" """A source for weights."""
@ -52,9 +56,15 @@ class DefaultModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for " extra_config = load_config.model_loader_extra_config
f"load format {load_config.load_format}") allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}")
def _prepare_weights( def _prepare_weights(
self, self,
@ -145,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader):
self, source: "Source" self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format.""" """Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt, source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides) source.allow_patterns_overrides)
@ -164,11 +175,29 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files, hf_weights_files,
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
) )
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = (
multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS),
))
else: else:
weights_iterator = safetensors_weights_iterator( weights_iterator = safetensors_weights_iterator(
hf_weights_files, hf_weights_files,
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
) )
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get("num_threads",
self.DEFAULT_NUM_THREADS),
)
else: else:
weights_iterator = pt_weights_iterator( weights_iterator = pt_weights_iterator(
hf_weights_files, hf_weights_files,

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for downloading and initializing model weights.""" """Utilities for downloading and initializing model weights."""
import concurrent.futures
import fnmatch import fnmatch
import glob import glob
import hashlib import hashlib
@ -531,6 +532,36 @@ def safetensors_weights_iterator(
yield name, param yield name, param
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over model weights in safetensors files, loading the shard
    files concurrently with a thread pool.

    Args:
        hf_weights_files: Paths of the safetensors shards to load.
        use_tqdm_on_load: Whether to display a progress bar while loading.
        max_workers: Number of loader threads in the pool.

    Yields:
        ``(parameter_name, tensor)`` pairs. Shards are yielded in
        completion order, which may differ from the input file order.
    """
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers) as executor:
        # Submit load_file directly; a wrapper function is unnecessary.
        futures = [
            executor.submit(load_file, st_file, device="cpu")
            for st_file in hf_weights_files
        ]
        futures_iter = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )
        for future in futures_iter:
            state_dict = future.result()
            yield from state_dict.items()
            # Drop our reference to the consumed shard promptly, matching
            # multi_thread_pt_weights_iterator. NOTE(review): the completed
            # Future object still holds a reference to the dict, so peak
            # memory is bounded by the number of in-flight shards — confirm
            # whether futures should be released incrementally as well.
            del state_dict
def runai_safetensors_weights_iterator( def runai_safetensors_weights_iterator(
hf_weights_files: list[str], hf_weights_files: list[str],
use_tqdm_on_load: bool, use_tqdm_on_load: bool,
@ -611,6 +642,39 @@ def pt_weights_iterator(
del state del state
def multi_thread_pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
    max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over model weights stored in PyTorch bin/pt checkpoint
    shards, reading the shard files concurrently with a pool of worker
    threads.

    Args:
        hf_weights_files: Paths of the checkpoint shards to load.
        use_tqdm_on_load: Whether to display a progress bar while loading.
        pt_load_map_location: ``map_location`` forwarded to ``torch.load``.
        max_workers: Number of loader threads in the pool.

    Yields:
        ``(parameter_name, tensor)`` pairs, shard by shard, in completion
        order (which may differ from the input file order).
    """

    def _read_shard(path: str):
        # weights_only=True restricts unpickling to tensor data, avoiding
        # arbitrary-code execution from untrusted checkpoints.
        return torch.load(path,
                          map_location=pt_load_map_location,
                          weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers) as pool:
        pending = [
            pool.submit(_read_shard, shard_file)
            for shard_file in hf_weights_files
        ]
        progress = tqdm(
            concurrent.futures.as_completed(pending),
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )
        for done in progress:
            shard = done.result()
            yield from shard.items()
            # Release our reference to the consumed shard before blocking
            # on the next future.
            del shard
def get_gguf_extra_tensor_names( def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
reader = gguf.GGUFReader(gguf_file) reader = gguf.GGUFReader(gguf_file)