mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 18:15:01 +08:00
149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Literal
|
|
|
|
from torch import nn
|
|
|
|
from vllm.config import ModelConfig, VllmConfig
|
|
from vllm.config.load import LoadConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
|
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
|
|
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
|
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
|
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
|
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
|
RunaiModelStreamerLoader,
|
|
)
|
|
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
|
|
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
|
from vllm.model_executor.model_loader.utils import (
|
|
get_architecture_class_name,
|
|
get_model_architecture,
|
|
get_model_cls,
|
|
)
|
|
|
|
# Module-level logger; `register_model_loader` uses it to report new
# registrations and overwrites of an existing load format.
logger = init_logger(__name__)
|
|
|
|
# Names of the built-in load formats. Each name listed here has a matching
# entry in `_LOAD_FORMAT_TO_MODEL_LOADER` below.
#
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
    "auto",
    "bitsandbytes",
    "dummy",
    "fastsafetensors",
    "gguf",
    "mistral",
    "npcache",
    "pt",
    "runai_streamer",
    "runai_streamer_sharded",
    "safetensors",
    "sharded_state",
    "tensorizer",
]
|
|
# Registry mapping a load-format name to the loader class that handles it.
# `get_model_loader` looks formats up here, and `register_model_loader` adds
# (or overwrites) entries at runtime. Several formats deliberately share one
# loader class (e.g. "auto", "pt" and "safetensors" all map to
# `DefaultModelLoader`).
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
    "auto": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}
|
|
|
|
|
|
def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Returns:
        A class decorator. It validates the decorated class, stores it in the
        loader registry under `load_format` (replacing any existing entry),
        and returns the class unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the decorated
            object is not a class derived from `BaseModelLoader`.

    Examples:
        >>> from vllm.config.load import LoadConfig
        >>> from vllm.model_executor.model_loader import (
        ...     get_model_loader,
        ...     register_model_loader,
        ... )
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self):
        ...         pass
        ...
        ...     def load_weights(self):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501

    def _wrapper(model_loader_cls):
        # Validate before logging anything, so a rejected registration never
        # emits a misleading "will be overwritten" warning. The isinstance
        # guard keeps non-class objects (e.g. a decorated function) on the
        # ValueError path instead of letting issubclass() raise TypeError.
        if not (
            isinstance(model_loader_cls, type)
            and issubclass(model_loader_cls, BaseModelLoader)
        ):
            raise ValueError(
                "The model loader must be a subclass of `BaseModelLoader`."
            )
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.",
                load_format,
                model_loader_cls,
            )
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
        logger.info(
            "Registered model loader `%s` with load format `%s`",
            model_loader_cls,
            load_format,
        )
        # Return the class itself so the decorated name stays usable.
        return model_loader_cls

    return _wrapper
|
|
|
|
|
|
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Instantiate the loader registered for ``load_config.load_format``.

    Args:
        load_config: Load configuration. Its ``load_format`` selects the
            loader class, and the config itself is passed to that class's
            constructor.

    Returns:
        A new instance of the loader class registered for the format.

    Raises:
        ValueError: If no loader is registered for the requested format.
    """
    requested_format = load_config.load_format
    if requested_format in _LOAD_FORMAT_TO_MODEL_LOADER:
        loader_cls = _LOAD_FORMAT_TO_MODEL_LOADER[requested_format]
        return loader_cls(load_config)
    raise ValueError(f"Load format `{requested_format}` is not supported")
|
|
|
|
|
|
def get_model(
    *, vllm_config: VllmConfig, model_config: ModelConfig | None = None
) -> nn.Module:
    """Load a model with the loader selected by ``vllm_config.load_config``.

    Args:
        vllm_config: Top-level configuration. Its ``load_config`` picks the
            loader, and it is forwarded to the loader's ``load_model``.
        model_config: Model configuration to load. When ``None``,
            ``vllm_config.model_config`` is used instead.

    Returns:
        Whatever the selected loader's ``load_model`` returns.
    """
    model_loader = get_model_loader(vllm_config.load_config)
    # An explicitly supplied model_config wins over the one in vllm_config.
    effective_model_config = (
        vllm_config.model_config if model_config is None else model_config
    )
    return model_loader.load_model(
        vllm_config=vllm_config, model_config=effective_model_config
    )
|
|
|
|
|
|
# Public API of this package: the loader entry points defined in this module,
# the helper functions imported from `model_loader.utils`, and the loader
# classes imported from their respective submodules.
__all__ = [
    # Entry points and re-exported helpers.
    "get_model",
    "get_model_loader",
    "get_architecture_class_name",
    "get_model_architecture",
    "get_model_cls",
    "register_model_loader",
    # Loader classes.
    "BaseModelLoader",
    "BitsAndBytesModelLoader",
    "GGUFModelLoader",
    "DefaultModelLoader",
    "DummyModelLoader",
    "RunaiModelStreamerLoader",
    "ShardedStateLoader",
    "TensorizerLoader",
]
|