[Core] Support model loader plugins (#21067)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn 2025-07-24 01:49:44 -07:00 committed by GitHub
parent f0f4de8f26
commit 610852a423
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 159 additions and 86 deletions

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import LoadFormat
test_model = "openai-community/gpt2" test_model = "openai-community/gpt2"
@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def test_model_loader_download_files(vllm_runner): def test_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, with vllm_runner(test_model, load_format="fastsafetensors") as llm:
load_format=LoadFormat.FASTSAFETENSORS) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params) deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs assert deserialized_outputs

View File

@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from torch import nn
from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader import (get_model_loader,
register_model_loader)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig) -> None:
super().__init__(load_config)
def download_model(self, model_config: ModelConfig) -> None:
pass
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
pass
def test_register_model_loader():
load_config = LoadConfig(load_format="custom_load_format")
assert isinstance(get_model_loader(load_config), CustomModelLoader)
def test_invalid_model_loader():
with pytest.raises(ValueError):
@register_model_loader("invalid_load_format")
class InValidModelLoader:
pass

View File

@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
load_format = "runai_streamer"
test_model = "openai-community/gpt2" test_model = "openai-community/gpt2"
prompts = [ prompts = [
@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def get_runai_model_loader(): def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER) load_config = LoadConfig(load_format=load_format)
return get_model_loader(load_config) return get_model_loader(load_config)
@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
def test_runai_model_loader_download_files(vllm_runner): def test_runai_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm: with vllm_runner(test_model, load_format=load_format) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params) deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs assert deserialized_outputs

View File

@ -65,7 +65,7 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.model_loader import BaseModelLoader from vllm.model_executor.model_loader import LoadFormats
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
ConfigType = type[DataclassInstance] ConfigType = type[DataclassInstance]
@ -78,6 +78,7 @@ else:
QuantizationConfig = Any QuantizationConfig = Any
QuantizationMethods = Any QuantizationMethods = Any
BaseModelLoader = Any BaseModelLoader = Any
LoadFormats = Any
TensorizerConfig = Any TensorizerConfig = Any
ConfigType = type ConfigType = type
HfOverrides = Union[dict[str, Any], Callable[[type], type]] HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@ -1773,29 +1774,12 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg) logger.warning("Possibly too large swap space. %s", msg)
class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors"
@config @config
@dataclass @dataclass
class LoadConfig: class LoadConfig:
"""Configuration for loading the model weights.""" """Configuration for loading the model weights."""
load_format: Union[str, LoadFormat, load_format: Union[str, LoadFormats] = "auto"
"BaseModelLoader"] = LoadFormat.AUTO.value
"""The format of the model weights to load:\n """The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall - "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n back to the pytorch bin format if safetensors format is not available.\n
@ -1816,7 +1800,8 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in - "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by - "mistral" will load weights from consolidated safetensors files used by
Mistral models.""" Mistral models.
- Other custom values can be supported via plugins."""
download_dir: Optional[str] = None download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default """Directory to download and load the weights, default to the default
cache directory of Hugging Face.""" cache directory of Hugging Face."""
@ -1864,10 +1849,7 @@ class LoadConfig:
return hash_str return hash_str
def __post_init__(self): def __post_init__(self):
if isinstance(self.load_format, str): self.load_format = self.load_format.lower()
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info( logger.info(
"Ignoring the following patterns when downloading weights: %s", "Ignoring the following patterns when downloading weights: %s",

View File

@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules, Device, DeviceConfig, DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, GuidedDecodingBackend, DistributedExecutorBackend, GuidedDecodingBackend,
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig, GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LoadFormat, KVTransferConfig, LoadConfig, LogprobsMode,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType, LoRAConfig, ModelConfig, ModelDType, ModelImpl,
ModelImpl, MultiModalConfig, ObservabilityConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig, SchedulerPolicy, SpeculativeConfig, TaskOption,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs, TokenizerMode, VllmConfig, get_attr_docs, get_field)
get_field)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
else: else:
ExecutorBase = Any ExecutorBase = Any
QuantizationMethods = Any QuantizationMethods = Any
LoadFormats = Any
UsageContext = Any UsageContext = Any
logger = init_logger(__name__) logger = init_logger(__name__)
@ -276,7 +277,7 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
load_format: str = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@ -547,9 +548,7 @@ class EngineArgs:
title="LoadConfig", title="LoadConfig",
description=LoadConfig.__doc__, description=LoadConfig.__doc__,
) )
load_group.add_argument("--load-format", load_group.add_argument("--load-format", **load_kwargs["load_format"])
choices=[f.value for f in LoadFormat],
**load_kwargs["load_format"])
load_group.add_argument("--download-dir", load_group.add_argument("--download-dir",
**load_kwargs["download_dir"]) **load_kwargs["download_dir"])
load_group.add_argument("--model-loader-extra-config", load_group.add_argument("--model-loader-extra-config",
@ -864,10 +863,9 @@ class EngineArgs:
# NOTE: This is to allow model loading from S3 in CI # NOTE: This is to allow model loading from S3 in CI
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
and self.model in MODELS_ON_S3 and self.model in MODELS_ON_S3 and self.load_format == "auto"):
and self.load_format == LoadFormat.AUTO): # noqa: E501
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
self.load_format = LoadFormat.RUNAI_STREAMER self.load_format = "runai_streamer"
return ModelConfig( return ModelConfig(
model=self.model, model=self.model,
@ -1299,7 +1297,7 @@ class EngineArgs:
############################################################# #############################################################
# Unsupported Feature Flags on V1. # Unsupported Feature Flags on V1.
if self.load_format == LoadFormat.SHARDED_STATE.value: if self.load_format == "sharded_state":
_raise_or_fallback( _raise_or_fallback(
feature_name=f"--load_format {self.load_format}", feature_name=f"--load_format {self.load_format}",
recommend_to_remove=False) recommend_to_remove=False)

View File

@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Literal, Optional
from torch import nn from torch import nn
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import ( from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader) BitsAndBytesModelLoader)
@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture, get_model_cls) get_architecture_class_name, get_model_architecture, get_model_cls)
logger = init_logger(__name__)
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
"auto",
"bitsandbytes",
"dummy",
"fastsafetensors",
"gguf",
"mistral",
"npcache",
"pt",
"runai_streamer",
"runai_streamer_sharded",
"safetensors",
"sharded_state",
"tensorizer",
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
"runai_streamer": RunaiModelStreamerLoader,
"runai_streamer_sharded": ShardedStateLoader,
"safetensors": DefaultModelLoader,
"sharded_state": ShardedStateLoader,
"tensorizer": TensorizerLoader,
}
def register_model_loader(load_format: str):
"""Register a customized vllm model loader.
When a load format is not supported by vllm, you can register a customized
model loader to support it.
Args:
load_format (str): The model loader format name.
Examples:
>>> from vllm.config import LoadConfig
>>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
>>>
>>> @register_model_loader("my_loader")
... class MyModelLoader(BaseModelLoader):
... def download_model(self):
... pass
...
... def load_weights(self):
... pass
>>>
>>> load_config = LoadConfig(load_format="my_loader")
>>> type(get_model_loader(load_config))
<class 'MyModelLoader'>
""" # noqa: E501
def _wrapper(model_loader_cls):
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
logger.warning(
"Load format `%s` is already registered, and will be "
"overwritten by the new loader class `%s`.", load_format,
model_loader_cls)
if not issubclass(model_loader_cls, BaseModelLoader):
raise ValueError("The model loader must be a subclass of "
"`BaseModelLoader`.")
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
logger.info("Registered model loader `%s` with load format `%s`",
model_loader_cls, load_format)
return model_loader_cls
return _wrapper
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format.""" """Get a model loader based on the load format."""
if isinstance(load_config.load_format, type): load_format = load_config.load_format
return load_config.load_format(load_config) if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
raise ValueError(f"Load format `{load_format}` is not supported")
if load_config.load_format == LoadFormat.DUMMY: return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)
if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)
return DefaultModelLoader(load_config)
def get_model(*, def get_model(*,
@ -66,6 +125,7 @@ __all__ = [
"get_architecture_class_name", "get_architecture_class_name",
"get_model_architecture", "get_model_architecture",
"get_model_cls", "get_model_cls",
"register_model_loader",
"BaseModelLoader", "BaseModelLoader",
"BitsAndBytesModelLoader", "BitsAndBytesModelLoader",
"GGUFModelLoader", "GGUFModelLoader",

View File

@ -13,7 +13,7 @@ from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs from vllm import envs
from vllm.config import LoadConfig, LoadFormat, ModelConfig from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
use_safetensors = False use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights. # Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO: if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"] allow_patterns = ["*.safetensors", "*.bin"]
elif (load_format == LoadFormat.SAFETENSORS elif (load_format == "safetensors"
or load_format == LoadFormat.FASTSAFETENSORS): or load_format == "fastsafetensors"):
use_safetensors = True use_safetensors = True
allow_patterns = ["*.safetensors"] allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL: elif load_format == "mistral":
use_safetensors = True use_safetensors = True
allow_patterns = ["consolidated*.safetensors"] allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json" index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT: elif load_format == "pt":
allow_patterns = ["*.pt"] allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE: elif load_format == "npcache":
allow_patterns = ["*.bin"] allow_patterns = ["*.bin"]
else: else:
raise ValueError(f"Unknown load_format: {load_format}") raise ValueError(f"Unknown load_format: {load_format}")
@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt, source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides) source.allow_patterns_overrides)
if self.load_config.load_format == LoadFormat.NPCACHE: if self.load_config.load_format == "npcache":
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints
assert use_safetensors is False assert use_safetensors is False
weights_iterator = np_cache_weights_iterator( weights_iterator = np_cache_weights_iterator(
@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
) )
elif use_safetensors: elif use_safetensors:
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator( weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files, hf_weights_files,
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,

View File

@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, def __init__(self, load_config: LoadConfig):
load_config: LoadConfig,
runai_model_streamer: bool = False):
super().__init__(load_config) super().__init__(load_config)
self.runai_model_streamer = runai_model_streamer
extra_config = ({} if load_config.model_loader_extra_config is None extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()) else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
def iterate_over_files( def iterate_over_files(
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
if self.runai_model_streamer: if self.load_config.load_format == "runai_streamer_sharded":
yield from runai_safetensors_weights_iterator(paths, True) yield from runai_safetensors_weights_iterator(paths, True)
else: else:
from safetensors.torch import safe_open from safetensors.torch import safe_open