# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import cast

import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    fastsafetensors_weights_iterator,
    filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference,
    get_quant_config,
    maybe_download_from_modelscope,
    multi_thread_pt_weights_iterator,
    multi_thread_safetensors_weights_iterator,
    np_cache_weights_iterator,
    pt_weights_iterator,
    safetensors_weights_iterator,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.repo_utils import list_filtered_repo_files

logger = init_logger(__name__)


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # Default number of threads to use when multithreaded weight loading
    # is enabled.
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: str | None
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weight names."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: list[str] | None = None
        """If defined, weights will load exclusively using these patterns."""
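
    # Example (illustrative): a model can expose extra checkpoints via a
    # `secondary_weights` attribute holding an iterable of Source entries,
    # which `get_all_weights` below consumes after the primary weights, e.g.:
    #
    #   DefaultModelLoader.Source(
    #       "org/extra-weights",      # hypothetical repo ID
    #       revision=None,
    #       prefix="draft_model.",    # hypothetical name prefix
    #   )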

    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )
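
    # Example (illustrative): the two recognized keys arrive through
    # `model_loader_extra_config`, e.g. a dict like
    #   {"enable_multithread_load": True, "num_threads": 8}
    # to read checkpoint files with 8 worker threads.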

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: str | None,
        fall_back_to_pt: bool,
        allow_patterns_overrides: list[str] | None,
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = (
            maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME

        # For the 'auto' format, first check whether Mistral-format files are
        # present, so that Mistral models load their official format by
        # default.
        if load_format == "auto":
            load_format = (
                "mistral"
                if len(
                    list_filtered_repo_files(
                        model_name_or_path=model_name_or_path,
                        allow_patterns=["consolidated*.safetensors"],
                        revision=revision,
                    )
                )
                > 0
                else "hf"
            )

        # Some quantized models use .pt files for storing the weights.
        if load_format == "hf":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == "safetensors" or load_format == "fastsafetensors":
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3, there are both sharded
            # safetensors files and a consolidated safetensors file; using
            # both breaks. Here, we download `model.safetensors.index.json`
            # and filter out any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors
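
    # Illustrative outcome (assumption, for a typical sharded safetensors
    # checkpoint such as "meta-llama/Llama-2-7b-hf" with load_format "auto"):
    #   hf_folder        -> the local snapshot directory
    #   hf_weights_files -> [".../model-00001-of-00002.safetensors", ...]
    #   use_safetensors  -> True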

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.revision,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        if self.load_config.load_format == "npcache":
            # Currently, np_cache only supports *.bin checkpoints.
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = multi_thread_safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        max_workers=extra_config.get(
                            "num_threads", self.DEFAULT_NUM_THREADS
                        ),
                    )
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if current_platform.is_tpu():
            from vllm.platforms.tpu import USE_TPU_INFERENCE

            if not USE_TPU_INFERENCE:
                # In PyTorch XLA, we should call `torch_xla.sync` frequently
                # so that not too many ops are accumulated in the XLA program.
                import torch_xla

                def _xla_weights_iterator(iterator: Generator):
                    for weights in iterator:
                        yield weights
                        torch_xla.sync(wait=False)

                weights_iterator = _xla_weights_iterator(weights_iterator)

        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
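
    # Sketch (assumption about the consumer): a model's `load_weights` would
    # typically drain this iterator along the lines of:
    #   params = dict(model.named_parameters())
    #   for name, tensor in weights_iterator:
    #       params[name].data.copy_(tensor)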

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(
            model_config.model,
            model_config.revision,
            fall_back_to_pt=True,
            allow_patterns_overrides=None,
        )

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        if model_config.quantization == "torchao":
            quant_config = get_quant_config(model_config, self.load_config)
            if (
                hasattr(quant_config, "is_checkpoint_torchao_serialized")
                and quant_config.is_checkpoint_torchao_serialized
                and torchao_version_at_least("0.15.0")
            ):
                self.load_config.safetensors_load_strategy = "torchao"

        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

        self.counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights - self.counter_before_loading_weights,
            scope="local",
        )
        # Currently, the strict check is enabled only for non-quantized models
        # whose loaded weights are tracked.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}"
                )
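
# Minimal usage sketch (illustrative; assumes `load_config`, `model_config`,
# and an initialized `model` are built elsewhere, e.g. by vLLM's engine):
#
#   loader = DefaultModelLoader(load_config)
#   loader.download_model(model_config)        # optional prefetch
#   loader.load_weights(model, model_config)   # stream weights into the model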