From 610852a42327836eb804c1a893fcb9bb6d52c29c Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Thu, 24 Jul 2025 01:49:44 -0700
Subject: [PATCH] [Core] Support model loader plugins (#21067)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
---
 .../test_fastsafetensors_loader.py            |   4 +-
 tests/model_executor/model_loader/__init__.py |   0
 .../model_loader/test_registry.py             |  37 ++++++
 .../test_runai_model_streamer_loader.py       |   7 +-
 vllm/config.py                                |  30 +----
 vllm/engine/arg_utils.py                      |  28 ++---
 vllm/model_executor/model_loader/__init__.py  | 114 +++++++++++++-----
 .../model_loader/default_loader.py            |  18 +--
 .../model_loader/sharded_state_loader.py      |   7 +-
 9 files changed, 159 insertions(+), 86 deletions(-)
 create mode 100644 tests/model_executor/model_loader/__init__.py
 create mode 100644 tests/model_executor/model_loader/test_registry.py

diff --git a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py
index 1b95bf59f67c6..afd411ff4874e 100644
--- a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py
+++ b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadFormat
 
 test_model = "openai-community/gpt2"
 
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 
 
 def test_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model,
-                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
+    with vllm_runner(test_model, load_format="fastsafetensors") as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs
diff --git a/tests/model_executor/model_loader/__init__.py b/tests/model_executor/model_loader/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py
new file mode 100644
index 0000000000000..93a3e34835b5a
--- /dev/null
+++ b/tests/model_executor/model_loader/test_registry.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from torch import nn
+
+from vllm.config import LoadConfig, ModelConfig
+from vllm.model_executor.model_loader import (get_model_loader,
+                                              register_model_loader)
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+
+
+@register_model_loader("custom_load_format")
+class CustomModelLoader(BaseModelLoader):
+
+    def __init__(self, load_config: LoadConfig) -> None:
+        super().__init__(load_config)
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        pass
+
+
+def test_register_model_loader():
+    load_config = LoadConfig(load_format="custom_load_format")
+    assert isinstance(get_model_loader(load_config), CustomModelLoader)
+
+
+def test_invalid_model_loader():
+    with pytest.raises(ValueError):
+
+        @register_model_loader("invalid_load_format")
+        class InvalidModelLoader:
+            pass
diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
index e27d9958f2917..84c615b6b8dbc 100644
--- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
+++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadConfig, LoadFormat
+from vllm.config import LoadConfig
 from vllm.model_executor.model_loader import get_model_loader
 
+load_format = "runai_streamer"
 test_model = "openai-community/gpt2"
 
 prompts = [
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 
 
 def get_runai_model_loader():
-    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    load_config = LoadConfig(load_format=load_format)
     return get_model_loader(load_config)
 
 
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
 
 
 def test_runai_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
+    with vllm_runner(test_model, load_format=load_format) as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs
diff --git a/vllm/config.py b/vllm/config.py
index f038cdd64c67a..07df71ec51ef3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
-    from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
@@ -78,6 +78,7 @@ else:
     QuantizationConfig = Any
     QuantizationMethods = Any
     BaseModelLoader = Any
+    LoadFormats = Any
     TensorizerConfig = Any
     ConfigType = type
     HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@@ -1773,29 +1774,12 @@ class CacheConfig:
             logger.warning("Possibly too large swap space. %s", msg)
 
 
-class LoadFormat(str, enum.Enum):
-    AUTO = "auto"
-    PT = "pt"
-    SAFETENSORS = "safetensors"
-    NPCACHE = "npcache"
-    DUMMY = "dummy"
-    TENSORIZER = "tensorizer"
-    SHARDED_STATE = "sharded_state"
-    GGUF = "gguf"
-    BITSANDBYTES = "bitsandbytes"
-    MISTRAL = "mistral"
-    RUNAI_STREAMER = "runai_streamer"
-    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
-    FASTSAFETENSORS = "fastsafetensors"
-
-
 @config
 @dataclass
 class LoadConfig:
     """Configuration for loading the model weights."""
 
-    load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    load_format: Union[str, LoadFormats] = "auto"
     """The format of the model weights to load:\n
     - "auto" will try to load the weights in the safetensors format and
     fall back to the pytorch bin format if safetensors format is not available.\n
@@ -1816,7 +1800,8 @@ class LoadConfig:
     - "gguf" will load weights from GGUF format files (details specified in
     https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
     - "mistral" will load weights from consolidated safetensors files used by
-    Mistral models."""
+    Mistral models.
+    - Other custom values can be supported via plugins."""
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
@@ -1864,10 +1849,7 @@ class LoadConfig:
         return hash_str
 
     def __post_init__(self):
-        if isinstance(self.load_format, str):
-            load_format = self.load_format.lower()
-            self.load_format = LoadFormat(load_format)
-
+        self.load_format = self.load_format.lower()
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index aec75f82631a2..7099680047182 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -26,13 +26,12 @@
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat,
-                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LogprobsMode,
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
 
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.usage.usage_lib import UsageContext
 else:
     ExecutorBase = Any
     QuantizationMethods = Any
+    LoadFormats = Any
     UsageContext = Any
 
 logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    load_format: str = LoadConfig.load_format
+    load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -547,9 +548,7 @@ class EngineArgs:
             title="LoadConfig",
             description=LoadConfig.__doc__,
         )
-        load_group.add_argument("--load-format",
-                                choices=[f.value for f in LoadFormat],
-                                **load_kwargs["load_format"])
+        load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
         load_group.add_argument("--model-loader-extra-config",
@@ -864,10 +863,9 @@ class EngineArgs:
 
         # NOTE: This is to allow model loading from S3 in CI
         if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
-                and self.model in MODELS_ON_S3
-                and self.load_format == LoadFormat.AUTO):  # noqa: E501
+                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
             self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
-            self.load_format = LoadFormat.RUNAI_STREAMER
+            self.load_format = "runai_streamer"
 
         return ModelConfig(
             model=self.model,
@@ -1299,7 +1297,7 @@ class EngineArgs:
         #############################################################
         # Unsupported Feature Flags on V1.
 
-        if self.load_format == LoadFormat.SHARDED_STATE.value:
+        if self.load_format == "sharded_state":
            _raise_or_fallback(
                 feature_name=f"--load_format {self.load_format}",
                 recommend_to_remove=False)
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 78681a0463710..2dada794a8f3e 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Literal, Optional
 
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
 from vllm.model_executor.model_loader.utils import (
     get_architecture_class_name, get_model_architecture, get_model_cls)
 
+logger = init_logger(__name__)
+
+# Reminder: Please update docstring in `LoadConfig`
+# if a new load format is added here
+LoadFormats = Literal[
+    "auto",
+    "bitsandbytes",
+    "dummy",
+    "fastsafetensors",
+    "gguf",
+    "mistral",
+    "npcache",
+    "pt",
+    "runai_streamer",
+    "runai_streamer_sharded",
+    "safetensors",
+    "sharded_state",
+    "tensorizer",
+]
+_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
+    "auto": DefaultModelLoader,
+    "bitsandbytes": BitsAndBytesModelLoader,
+    "dummy": DummyModelLoader,
+    "fastsafetensors": DefaultModelLoader,
+    "gguf": GGUFModelLoader,
+    "mistral": DefaultModelLoader,
+    "npcache": DefaultModelLoader,
+    "pt": DefaultModelLoader,
+    "runai_streamer": RunaiModelStreamerLoader,
+    "runai_streamer_sharded": ShardedStateLoader,
+    "safetensors": DefaultModelLoader,
+    "sharded_state": ShardedStateLoader,
+    "tensorizer": TensorizerLoader,
+}
+
+
+def register_model_loader(load_format: str):
+    """Register a customized vLLM model loader.
+
+    When a load format is not supported by vLLM, you can register a customized
+    model loader to support it.
+
+    Args:
+        load_format (str): The model loader format name.
+
+    Examples:
+        >>> from vllm.config import LoadConfig
+        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
+        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+        >>>
+        >>> @register_model_loader("my_loader")
+        ... class MyModelLoader(BaseModelLoader):
+        ...     def download_model(self, model_config):
+        ...         pass
+        ...
+        ...     def load_weights(self, model, model_config):
+        ...         pass
+        >>>
+        >>> load_config = LoadConfig(load_format="my_loader")
+        >>> type(get_model_loader(load_config))
+        <class 'MyModelLoader'>
+    """  # noqa: E501
+
+    def _wrapper(model_loader_cls):
+        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
+            logger.warning(
+                "Load format `%s` is already registered, and will be "
+                "overwritten by the new loader class `%s`.", load_format,
+                model_loader_cls)
+        if not issubclass(model_loader_cls, BaseModelLoader):
+            raise ValueError("The model loader must be a subclass of "
+                             "`BaseModelLoader`.")
+        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
+        logger.info("Registered model loader `%s` with load format `%s`",
+                    model_loader_cls, load_format)
+        return model_loader_cls
+
+    return _wrapper
+
 
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
-    if isinstance(load_config.load_format, type):
-        return load_config.load_format(load_config)
-
-    if load_config.load_format == LoadFormat.DUMMY:
-        return DummyModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.TENSORIZER:
-        return TensorizerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.SHARDED_STATE:
-        return ShardedStateLoader(load_config)
-
-    if load_config.load_format == LoadFormat.BITSANDBYTES:
-        return BitsAndBytesModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.GGUF:
-        return GGUFModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
-        return RunaiModelStreamerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
-        return ShardedStateLoader(load_config, runai_model_streamer=True)
-
-    return DefaultModelLoader(load_config)
+    load_format = load_config.load_format
+    if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
+        raise ValueError(f"Load format `{load_format}` is not supported")
+    return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
 
 
 def get_model(*,
@@ -66,6 +125,7 @@ __all__ = [
     "get_architecture_class_name",
     "get_model_architecture",
     "get_model_cls",
+    "register_model_loader",
     "BaseModelLoader",
     "BitsAndBytesModelLoader",
     "GGUFModelLoader",
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 2fcae7eb6e6c5..36568e881ebb1 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -13,7 +13,7 @@ from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm import envs
-from vllm.config import LoadConfig, LoadFormat, ModelConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.weight_utils import (
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
         use_safetensors = False
         index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
-        if load_format == LoadFormat.AUTO:
+        if load_format == "auto":
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif (load_format == LoadFormat.SAFETENSORS
-              or load_format == LoadFormat.FASTSAFETENSORS):
+        elif (load_format == "safetensors"
+              or load_format == "fastsafetensors"):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
-        elif load_format == LoadFormat.MISTRAL:
+        elif load_format == "mistral":
             use_safetensors = True
             allow_patterns = ["consolidated*.safetensors"]
             index_file = "consolidated.safetensors.index.json"
-        elif load_format == LoadFormat.PT:
+        elif load_format == "pt":
             allow_patterns = ["*.pt"]
-        elif load_format == LoadFormat.NPCACHE:
+        elif load_format == "npcache":
             allow_patterns = ["*.bin"]
         else:
             raise ValueError(f"Unknown load_format: {load_format}")
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
-        if self.load_config.load_format == LoadFormat.NPCACHE:
+        if self.load_config.load_format == "npcache":
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+            if self.load_config.load_format == "fastsafetensors":
                 weights_iterator = fastsafetensors_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py
index 2fd9cfba3f61a..3edd4ec4007e8 100644
--- a/vllm/model_executor/model_loader/sharded_state_loader.py
+++ b/vllm/model_executor/model_loader/sharded_state_loader.py
@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self,
-                 load_config: LoadConfig,
-                 runai_model_streamer: bool = False):
+    def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
 
-        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
 
     def iterate_over_files(
             self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
-        if self.runai_model_streamer:
+        if self.load_config.load_format == "runai_streamer_sharded":
             yield from runai_safetensors_weights_iterator(paths, True)
         else:
             from safetensors.torch import safe_open
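
Reviewer note: below is a minimal, self-contained sketch of how an out-of-tree package might exercise the new hook, in the same spirit as the doctest in `register_model_loader`. The `LazyModelLoader` class, the "lazy_loader" format name, and the no-op method bodies are hypothetical; only `register_model_loader`, `get_model_loader`, `BaseModelLoader`, `LoadConfig`, and `ModelConfig` are names introduced or touched by this patch.

# Hypothetical out-of-tree plugin exercising the registry added by this patch.
# `LazyModelLoader` and the "lazy_loader" format name are illustrative only.
from torch import nn

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import (get_model_loader,
                                              register_model_loader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


@register_model_loader("lazy_loader")
class LazyModelLoader(BaseModelLoader):
    """Skips eager downloads and defers all weight I/O to load_weights()."""

    def download_model(self, model_config: ModelConfig) -> None:
        # Nothing to fetch: weights are assumed to already sit on local disk.
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        # A real plugin would stream tensors into `model` here.
        pass


# The registry, not an enum, now resolves format names, so the custom
# format routes to the plugin class:
loader = get_model_loader(LoadConfig(load_format="lazy_loader"))
assert isinstance(loader, LazyModelLoader)

Because `get_model_loader` consults `_LOAD_FORMAT_TO_MODEL_LOADER` at call time, the decorator only has to run before the engine is constructed, e.g. from a package loaded by `load_general_plugins()`; and since the `choices` restriction on `--load-format` was removed, `--load-format lazy_loader` would reach the same code path from the CLI.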