[Core] Support configuration parsing plugin (#24277)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Xingyu Liu, 2025-09-10 11:32:43 -07:00, committed by GitHub
parent 4032949630
commit 9fb74c27a7
6 changed files with 237 additions and 107 deletions

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pathlib import Path
from typing import Optional, Union

import pytest
from transformers import PretrainedConfig

from vllm.transformers_utils.config import (get_config_parser,
                                            register_config_parser)
from vllm.transformers_utils.config_parser_base import ConfigParserBase


@register_config_parser("custom_config_parser")
class CustomConfigParser(ConfigParserBase):

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        raise NotImplementedError


def test_register_config_parser():
    assert isinstance(get_config_parser("custom_config_parser"),
                      CustomConfigParser)


def test_invalid_config_parser():
    with pytest.raises(ValueError):

        @register_config_parser("invalid_config_parser")
        class InvalidConfigParser:
            pass

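For illustration, here is how a registered parser plugs into config loading end to end. This sketch is not part of the commit: the format name `params_json_parser`, the `ParamsJsonParser` class, and the plain-JSON layout are hypothetical; only `register_config_parser`, `ConfigParserBase`, and `get_config` come from this change.

```python
import json
from pathlib import Path
from typing import Optional, Union

from transformers import PretrainedConfig

from vllm.transformers_utils.config import get_config, register_config_parser
from vllm.transformers_utils.config_parser_base import ConfigParserBase


@register_config_parser("params_json_parser")  # hypothetical format name
class ParamsJsonParser(ConfigParserBase):
    """Toy parser: builds a PretrainedConfig from <model>/params.json."""

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        config_dict = json.loads((Path(model) / "params.json").read_text())
        return config_dict, PretrainedConfig(**config_dict)


# get_config() dispatches on the format name, so the custom parser is
# selected by passing it as config_format:
# config = get_config("/path/to/model", trust_remote_code=False,
#                     config_format="params_json_parser")
```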

@@ -419,7 +419,7 @@ class ModelConfig:
    `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
    use_async_output_proc: bool = True
    """Whether to use async output processor."""
-    config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
+    config_format: Union[str, ConfigFormat] = "auto"
    """The format of the model config to load:\n
    - "auto" will try to load the config in hf format if available else it
      will try to load in mistral format.\n
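As a usage sketch (not part of this diff), the field can be set explicitly to skip auto-detection. This assumes the `LLM` entry point forwards `config_format` through the engine args to `ModelConfig`, which this commit does not itself show:

```python
from vllm import LLM

# Skip "auto" detection and force the Mistral params.json code path.
# (Assumes `config_format` is forwarded through EngineArgs to ModelConfig.)
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          config_format="mistral")
```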
@@ -624,9 +624,6 @@ class ModelConfig:
            raise ValueError(
                "Sleep mode is not supported on current platform.")

-        if isinstance(self.config_format, str):
-            self.config_format = ConfigFormat(self.config_format)
-
        hf_config = get_config(self.hf_config_path or self.model,
                               self.trust_remote_code,
                               self.revision,


@@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
import vllm.envs as envs

from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
-                         ConfigFormat, ConfigType, ConvertOption,
-                         DecodingConfig, DetailedTraceModules, Device,
-                         DeviceConfig, DistributedExecutorBackend, EPLBConfig,
+                         ConfigType, ConvertOption, DecodingConfig,
+                         DetailedTraceModules, Device, DeviceConfig,
+                         DistributedExecutorBackend, EPLBConfig,
                         GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LogprobsMode,
                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
@@ -547,7 +547,6 @@ class EngineArgs:
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
-                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs

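Dropping the hard-coded `choices` list is what lets plugin-registered format names through argument parsing. A sketch, assuming a parser was registered under `custom_config_parser` before the args are built (`./my-model` is a placeholder path):

```python
from vllm.engine.arg_utils import EngineArgs

# Any registered format name is now accepted; previously only the
# ConfigFormat enum values ("auto", "hf", "mistral") passed validation.
args = EngineArgs(model="./my-model",
                  config_format="custom_config_parser")
```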

@@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import enum
import json
import os
import time
from functools import cache, partial
from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Literal, Optional, TypeVar, Union

import huggingface_hub
from huggingface_hub import get_safetensors_metadata, hf_hub_download
@@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
+from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import check_gguf_file

if envs.VLLM_USE_MODELSCOPE:
@@ -100,10 +100,163 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
}

-class ConfigFormat(str, enum.Enum):
-    AUTO = "auto"
-    HF = "hf"
-    MISTRAL = "mistral"


class HFConfigParser(ConfigParserBase):

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
        config_dict, _ = PretrainedConfig.get_config_dict(
            model,
            revision=revision,
            code_revision=code_revision,
            token=_get_hf_token(),
            **kwargs,
        )

        # Use custom model class if it's in our registry
        model_type = config_dict.get("model_type")
        if model_type is None:
            model_type = "speculators" if config_dict.get(
                "speculators_config") is not None else model_type

        if model_type in _CONFIG_REGISTRY:
            config_class = _CONFIG_REGISTRY[model_type]
            config = config_class.from_pretrained(
                model,
                revision=revision,
                code_revision=code_revision,
                token=_get_hf_token(),
                **kwargs,
            )
        else:
            try:
                kwargs = _maybe_update_auto_config_kwargs(
                    kwargs, model_type=model_type)
                config = AutoConfig.from_pretrained(
                    model,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    code_revision=code_revision,
                    token=_get_hf_token(),
                    **kwargs,
                )
            except ValueError as e:
                if (not trust_remote_code
                        and "requires you to execute the configuration file"
                        in str(e)):
                    err_msg = (
                        "Failed to load the model config. If the model "
                        "is a custom model not yet available in the "
                        "HuggingFace transformers library, consider setting "
                        "`trust_remote_code=True` in LLM or using the "
                        "`--trust-remote-code` flag in the CLI.")
                    raise RuntimeError(err_msg) from e
                else:
                    raise e
            config = _maybe_remap_hf_config_attrs(config)

        return config_dict, config


class MistralConfigParser(ConfigParserBase):

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        # This function loads a params.json config which
        # should be used when loading models in mistral format
        config_dict = _download_mistral_config_file(model, revision)
        if (max_position_embeddings :=
                config_dict.get("max_position_embeddings")) is None:
            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
                model, revision, **kwargs)
            config_dict["max_position_embeddings"] = max_position_embeddings

        from vllm.transformers_utils.configs.mistral import adapt_config_dict

        config = adapt_config_dict(config_dict)

        # Mistral configs may define sliding_window as list[int]. Convert it
        # to int and add the layer_types list[str] to make it HF compatible
        if ((sliding_window := getattr(config, "sliding_window", None))
                and isinstance(sliding_window, list)):
            pattern_repeats = config.num_hidden_layers // len(sliding_window)
            layer_types = sliding_window * pattern_repeats
            config.layer_types = [
                "full_attention" if layer_type is None else "sliding_attention"
                for layer_type in layer_types
            ]
            config.sliding_window = next(filter(None, sliding_window), None)

        return config_dict, config
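To make the sliding-window conversion above concrete, here is a small worked example; the layer count and window size are illustrative, not taken from any real checkpoint:

```python
# Mistral-style config: alternating sliding/full attention across layers.
num_hidden_layers = 8
sliding_window = [4096, None]

pattern_repeats = num_hidden_layers // len(sliding_window)  # 4
layer_types = sliding_window * pattern_repeats
# [4096, None, 4096, None, 4096, None, 4096, None]

hf_layer_types = [
    "full_attention" if layer_type is None else "sliding_attention"
    for layer_type in layer_types
]
# ["sliding_attention", "full_attention", ...] repeated 4 times
hf_sliding_window = next(filter(None, sliding_window), None)  # 4096
```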
_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = {
    "hf": HFConfigParser,
    "mistral": MistralConfigParser,
}

ConfigFormat = Literal[
    "auto",
    "hf",
    "mistral",
]


def get_config_parser(config_format: str) -> ConfigParserBase:
    """Get the config parser for a given config format."""
    if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER:
        raise ValueError(f"Unknown config format `{config_format}`.")
    return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]()
def register_config_parser(config_format: str):
    """Register a customized vllm config parser.

    When a config format is not supported by vllm, you can register a
    customized config parser to support it.

    Args:
        config_format (str): The config parser format name.

    Examples:
        >>> from vllm.transformers_utils.config import (get_config_parser,
        ...                                             register_config_parser)
        >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase
        >>>
        >>> @register_config_parser("custom_config_parser")
        ... class CustomConfigParser(ConfigParserBase):
        ...     def parse(self,
        ...               model: Union[str, Path],
        ...               trust_remote_code: bool,
        ...               revision: Optional[str] = None,
        ...               code_revision: Optional[str] = None,
        ...               **kwargs) -> tuple[dict, PretrainedConfig]:
        ...         raise NotImplementedError
        >>>
        >>> type(get_config_parser("custom_config_parser"))
        <class 'CustomConfigParser'>
    """  # noqa: E501

    def _wrapper(config_parser_cls):
        if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
            logger.warning(
                "Config format `%s` is already registered, and will be "
                "overwritten by the new parser class `%s`.", config_format,
                config_parser_cls)
        if not issubclass(config_parser_cls, ConfigParserBase):
            raise ValueError("The config parser must be a subclass of "
                             "`ConfigParserBase`.")
        _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
        logger.info("Registered config parser `%s` with config format `%s`",
                    config_parser_cls, config_format)
        return config_parser_cls

    return _wrapper
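The lookup and error paths of `get_config_parser` can be exercised directly; the snippet below uses only names introduced in this file:

```python
from vllm.transformers_utils.config import (HFConfigParser,
                                            MistralConfigParser,
                                            get_config_parser)

# Built-in formats resolve to the parsers defined above.
assert isinstance(get_config_parser("hf"), HFConfigParser)
assert isinstance(get_config_parser("mistral"), MistralConfigParser)

# Unknown names raise until a parser is registered for them.
try:
    get_config_parser("not_registered")
except ValueError as e:
    print(e)  # Unknown config format `not_registered`.
```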
_R = TypeVar("_R")
@@ -350,7 +503,7 @@ def get_config(
    trust_remote_code: bool,
    revision: Optional[str] = None,
    code_revision: Optional[str] = None,
-    config_format: ConfigFormat = ConfigFormat.AUTO,
+    config_format: Union[str, ConfigFormat] = "auto",
    hf_overrides_kw: Optional[dict[str, Any]] = None,
    hf_overrides_fn: Optional[Callable[[PretrainedConfig],
                                       PretrainedConfig]] = None,
@@ -363,20 +516,22 @@ def get_config(
        kwargs["gguf_file"] = Path(model).name
        model = Path(model).parent

-    if config_format == ConfigFormat.AUTO:
+    if config_format == "auto":
        try:
            if is_gguf or file_or_path_exists(
                    model, HF_CONFIG_NAME, revision=revision):
-                config_format = ConfigFormat.HF
+                config_format = "hf"
            elif file_or_path_exists(model,
                                     MISTRAL_CONFIG_NAME,
                                     revision=revision):
-                config_format = ConfigFormat.MISTRAL
+                config_format = "mistral"
            else:
                raise ValueError(
                    "Could not detect config format for no config file found. "
-                    "Ensure your model has either config.json (HF format) "
-                    "or params.json (Mistral format).")
+                    "With config_format 'auto', ensure your model has either "
+                    "config.json (HF format) or params.json (Mistral format). "
+                    "Otherwise, please specify your custom config format in "
+                    "the engine args so the registered config parser is used.")

        except Exception as e:
            error_message = (
@@ -395,92 +550,14 @@ def get_config(
            raise ValueError(error_message) from e

-    if config_format == ConfigFormat.HF:
-        kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model,
-            revision=revision,
-            code_revision=code_revision,
-            token=_get_hf_token(),
-            **kwargs,
-        )
-
-        # Use custom model class if it's in our registry
-        model_type = config_dict.get("model_type")
-        if model_type is None:
-            model_type = "speculators" if config_dict.get(
-                "speculators_config") is not None else model_type
-
-        if model_type in _CONFIG_REGISTRY:
-            config_class = _CONFIG_REGISTRY[model_type]
-            config = config_class.from_pretrained(
-                model,
-                revision=revision,
-                code_revision=code_revision,
-                token=_get_hf_token(),
-                **kwargs,
-            )
-        else:
-            try:
-                kwargs = _maybe_update_auto_config_kwargs(
-                    kwargs, model_type=model_type)
-                config = AutoConfig.from_pretrained(
-                    model,
-                    trust_remote_code=trust_remote_code,
-                    revision=revision,
-                    code_revision=code_revision,
-                    token=_get_hf_token(),
-                    **kwargs,
-                )
-            except ValueError as e:
-                if (not trust_remote_code
-                        and "requires you to execute the configuration file"
-                        in str(e)):
-                    err_msg = (
-                        "Failed to load the model config. If the model "
-                        "is a custom model not yet available in the "
-                        "HuggingFace transformers library, consider setting "
-                        "`trust_remote_code=True` in LLM or using the "
-                        "`--trust-remote-code` flag in the CLI.")
-                    raise RuntimeError(err_msg) from e
-                else:
-                    raise e
-            config = _maybe_remap_hf_config_attrs(config)
-
-    elif config_format == ConfigFormat.MISTRAL:
-        # This function loads a params.json config which
-        # should be used when loading models in mistral format
-        config_dict = _download_mistral_config_file(model, revision)
-        if (max_position_embeddings :=
-                config_dict.get("max_position_embeddings")) is None:
-            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
-                model, revision, **kwargs)
-            config_dict["max_position_embeddings"] = max_position_embeddings
-
-        from vllm.transformers_utils.configs.mistral import adapt_config_dict
-        config = adapt_config_dict(config_dict)
-
-        # Mistral configs may define sliding_window as list[int]. Convert it
-        # to int and add the layer_types list[str] to make it HF compatible
-        if ((sliding_window := getattr(config, "sliding_window", None))
-                and isinstance(sliding_window, list)):
-            pattern_repeats = config.num_hidden_layers // len(sliding_window)
-            layer_types = sliding_window * pattern_repeats
-            config.layer_types = [
-                "full_attention" if layer_type is None else "sliding_attention"
-                for layer_type in layer_types
-            ]
-            config.sliding_window = next(filter(None, sliding_window), None)
-    else:
-        supported_formats = [
-            fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
-        ]
-        raise ValueError(
-            f"Unsupported config format: {config_format}. "
-            f"Supported formats are: {', '.join(supported_formats)}. "
-            f"Ensure your model uses one of these configuration formats "
-            f"or specify the correct format explicitly.")
+    config_parser = get_config_parser(config_format)
+    config_dict, config = config_parser.parse(
+        model,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        code_revision=code_revision,
+        **kwargs,
+    )

    # Special architecture mapping check for GGUF models
    if is_gguf:
        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -914,7 +991,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
        hf_config = get_config(model=model,
                               trust_remote_code=trust_remote_code_val,
                               revision=revision,
-                               config_format=ConfigFormat.HF)
+                               config_format="hf")
        if hf_value := hf_config.get_text_config().max_position_embeddings:
            max_position_embeddings = hf_value
    except Exception as e:


@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union

from transformers import PretrainedConfig


class ConfigParserBase(ABC):

    @abstractmethod
    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        raise NotImplementedError
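Since `parse` is the only abstract method, a subclass can also build on an existing parser rather than starting from scratch. A hypothetical sketch (the `hf_patched` format name, `PatchedHFConfigParser` class, and the max-length cap are all illustrative) that delegates to `HFConfigParser` and post-processes the result:

```python
from pathlib import Path
from typing import Optional, Union

from transformers import PretrainedConfig

from vllm.transformers_utils.config import (HFConfigParser,
                                            register_config_parser)


@register_config_parser("hf_patched")  # hypothetical format name
class PatchedHFConfigParser(HFConfigParser):
    """Delegate to the stock HF parser, then patch the loaded config."""

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        config_dict, config = super().parse(model, trust_remote_code,
                                            revision, code_revision, **kwargs)
        # Illustrative post-processing: cap the context window at 8192.
        if getattr(config, "max_position_embeddings", None):
            config.max_position_embeddings = min(
                config.max_position_embeddings, 8192)
        return config_dict, config
```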