[Model loader]: support multi-thread model weight loading (#23928)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
parent 7be141b2c5
commit 43d9ad03ba
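This commit lets callers opt in to multi-threaded weight loading through model_loader_extra_config; the diff below first updates the default loader (DefaultModelLoader) and then adds the two new iterators in vllm/model_executor/model_loader/weight_utils.py. Below is a minimal sketch of how the option would be switched on, assuming the existing model_loader_extra_config engine argument and using only the two keys this commit validates (enable_multithread_load, num_threads); the model name is a placeholder:

from vllm import LLM

# Sketch: enable the multi-thread weight loader added in this commit.
# "enable_multithread_load" and "num_threads" are the keys validated by
# DefaultModelLoader.__init__ below; num_threads falls back to
# DEFAULT_NUM_THREADS (8) when omitted.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    model_loader_extra_config={
        "enable_multithread_load": True,
        "num_threads": 8,
    },
)

On the command line, the same dictionary would typically be passed as a JSON string through the existing --model-loader-extra-config flag.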
vllm/model_executor/model_loader/default_loader.py
@@ -18,8 +18,9 @@ from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference, maybe_download_from_modelscope,
-    np_cache_weights_iterator, pt_weights_iterator,
-    safetensors_weights_iterator)
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
+    pt_weights_iterator, safetensors_weights_iterator)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -28,6 +29,9 @@ logger = init_logger(__name__)
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # Default number of threads when multi-thread weight loading is enabled.
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -52,9 +56,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
-            raise ValueError(f"Model loader extra config is not supported for "
-                             f"load format {load_config.load_format}")
+
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{unexpected_keys}")
 
     def _prepare_weights(
         self,
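For illustration, a small sketch of how the new key validation behaves. It assumes LoadConfig can be constructed with defaults and that DefaultModelLoader lives in the default_loader module; both are assumptions of this sketch, not part of the diff:

from vllm.config import LoadConfig
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader

# Allowed: only "enable_multithread_load" and "num_threads" pass validation.
DefaultModelLoader(
    LoadConfig(model_loader_extra_config={
        "enable_multithread_load": True,
        "num_threads": 4,
    }))

# Any other key now raises "Unexpected extra config keys ..." instead of the
# old blanket "extra config is not supported" error.
try:
    DefaultModelLoader(
        LoadConfig(model_loader_extra_config={"enable_multithread": True}))
except ValueError as e:
    print(e)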
@@ -145,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
@@ -164,11 +175,29 @@
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
                 )
             else:
-                weights_iterator = safetensors_weights_iterator(
-                    hf_weights_files,
-                    self.load_config.use_tqdm_on_load,
-                )
+                if extra_config.get("enable_multithread_load"):
+                    weights_iterator = (
+                        multi_thread_safetensors_weights_iterator(
+                            hf_weights_files,
+                            self.load_config.use_tqdm_on_load,
+                            max_workers=extra_config.get(
+                                "num_threads", self.DEFAULT_NUM_THREADS),
+                        ))
+                else:
+                    weights_iterator = safetensors_weights_iterator(
+                        hf_weights_files,
+                        self.load_config.use_tqdm_on_load,
+                    )
         else:
-            weights_iterator = pt_weights_iterator(
-                hf_weights_files,
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                    self.load_config.pt_load_map_location,
+                    max_workers=extra_config.get("num_threads",
+                                                 self.DEFAULT_NUM_THREADS),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(
+                    hf_weights_files,
vllm/model_executor/model_loader/weight_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
@@ -531,6 +532,36 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Multi-threaded iteration over the model's safetensor weight files."""
+
+    def _load_file(st_file: str):
+        result = load_file(st_file, device="cpu")
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, st_file)
+            for st_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state_dict = future.result()
+            yield from state_dict.items()
+
+
 def runai_safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
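A usage sketch for the new safetensors iterator on its own, outside the loader; the import path comes from the import block at the top of this diff, while the shard paths are placeholders:

from vllm.model_executor.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator)

shard_files = [  # placeholder shard paths
    "/models/llama/model-00001-of-00002.safetensors",
    "/models/llama/model-00002-of-00002.safetensors",
]

# Shards are loaded on up to max_workers threads and yielded as each file
# finishes, so tensors from different shards may interleave.
for name, tensor in multi_thread_safetensors_weights_iterator(
        shard_files, use_tqdm_on_load=True, max_workers=8):
    print(name, tuple(tensor.shape))

Because results come back in as_completed order, weight names are not yielded in a deterministic order; the loader is indifferent to this since weights are matched by name.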
@@ -611,6 +642,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Multi-threaded iteration over the model's bin/pt weight files."""
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file,
+                          map_location=pt_load_map_location,
+                          weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file)
+            for bin_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading pt checkpoint shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+            del state
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
     reader = gguf.GGUFReader(gguf_file)
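Whether the thread pool actually helps depends on storage; below is a rough timing sketch (placeholder shard paths, no claims about the numbers) that drains the single-threaded and multi-threaded safetensors iterators back to back:

import time

from vllm.model_executor.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator, safetensors_weights_iterator)


def drain(it):
    # Count yielded tensors; forces every shard to actually be read.
    return sum(1 for _ in it)


shard_files = [  # placeholder shard paths
    "/models/llama/model-00001-of-00002.safetensors",
    "/models/llama/model-00002-of-00002.safetensors",
]

t0 = time.perf_counter()
n = drain(safetensors_weights_iterator(shard_files, use_tqdm_on_load=False))
t1 = time.perf_counter()
drain(multi_thread_safetensors_weights_iterator(
    shard_files, use_tqdm_on_load=False, max_workers=8))
t2 = time.perf_counter()
print(f"{n} tensors: single-thread {t1 - t0:.1f}s, multi-thread {t2 - t1:.1f}s")

The gain comes mainly from overlapping per-file I/O across threads, which is why the defaults stay small: max_workers=4 in the iterator functions, and 8 via the loader's DEFAULT_NUM_THREADS when only enable_multithread_load is set.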