mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 12:25:01 +08:00
[Core] Support configuration parsing plugin (#24277)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com> Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
4032949630
commit
9fb74c27a7
0
tests/transformers_utils/__init__.py
Normal file
0
tests/transformers_utils/__init__.py
Normal file
37
tests/transformers_utils/test_config_parser_registry.py
Normal file
37
tests/transformers_utils/test_config_parser_registry.py
Normal file
@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.transformers_utils.config import (get_config_parser,
|
||||
register_config_parser)
|
||||
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||
|
||||
|
||||
@register_config_parser("custom_config_parser")
|
||||
class CustomConfigParser(ConfigParserBase):
|
||||
|
||||
def parse(self,
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_register_config_parser():
|
||||
assert isinstance(get_config_parser("custom_config_parser"),
|
||||
CustomConfigParser)
|
||||
|
||||
|
||||
def test_invalid_config_parser():
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@register_config_parser("invalid_config_parser")
|
||||
class InvalidConfigParser:
|
||||
pass
|
||||
@ -419,7 +419,7 @@ class ModelConfig:
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||
use_async_output_proc: bool = True
|
||||
"""Whether to use async output processor."""
|
||||
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
|
||||
config_format: Union[str, ConfigFormat] = "auto"
|
||||
"""The format of the model config to load:\n
|
||||
- "auto" will try to load the config in hf format if available else it
|
||||
will try to load in mistral format.\n
|
||||
@ -624,9 +624,6 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
"Sleep mode is not supported on current platform.")
|
||||
|
||||
if isinstance(self.config_format, str):
|
||||
self.config_format = ConfigFormat(self.config_format)
|
||||
|
||||
hf_config = get_config(self.hf_config_path or self.model,
|
||||
self.trust_remote_code,
|
||||
self.revision,
|
||||
|
||||
@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
ConfigFormat, ConfigType, ConvertOption,
|
||||
DecodingConfig, DetailedTraceModules, Device,
|
||||
DeviceConfig, DistributedExecutorBackend, EPLBConfig,
|
||||
ConfigType, ConvertOption, DecodingConfig,
|
||||
DetailedTraceModules, Device, DeviceConfig,
|
||||
DistributedExecutorBackend, EPLBConfig,
|
||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||
@ -547,7 +547,6 @@ class EngineArgs:
|
||||
help="Disable async output processing. This may result in "
|
||||
"lower performance.")
|
||||
model_group.add_argument("--config-format",
|
||||
choices=[f.value for f in ConfigFormat],
|
||||
**model_kwargs["config_format"])
|
||||
# This one is a special case because it can bool
|
||||
# or str. TODO: Handle this in get_kwargs
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from functools import cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, Literal, Optional, TypeVar, Union
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import get_safetensors_metadata, hf_hub_download
|
||||
@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
@ -100,10 +100,163 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
|
||||
}
|
||||
|
||||
|
||||
class ConfigFormat(str, enum.Enum):
|
||||
AUTO = "auto"
|
||||
HF = "hf"
|
||||
MISTRAL = "mistral"
|
||||
class HFConfigParser(ConfigParserBase):
|
||||
|
||||
def parse(self,
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
# Use custom model class if it's in our registry
|
||||
model_type = config_dict.get("model_type")
|
||||
if model_type is None:
|
||||
model_type = "speculators" if config_dict.get(
|
||||
"speculators_config") is not None else model_type
|
||||
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
kwargs = _maybe_update_auto_config_kwargs(
|
||||
kwargs, model_type=model_type)
|
||||
config = AutoConfig.from_pretrained(
|
||||
model,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code
|
||||
and "requires you to execute the configuration file"
|
||||
in str(e)):
|
||||
err_msg = (
|
||||
"Failed to load the model config. If the model "
|
||||
"is a custom model not yet available in the "
|
||||
"HuggingFace transformers library, consider setting "
|
||||
"`trust_remote_code=True` in LLM or using the "
|
||||
"`--trust-remote-code` flag in the CLI.")
|
||||
raise RuntimeError(err_msg) from e
|
||||
else:
|
||||
raise e
|
||||
config = _maybe_remap_hf_config_attrs(config)
|
||||
return config_dict, config
|
||||
|
||||
|
||||
class MistralConfigParser(ConfigParserBase):
|
||||
|
||||
def parse(self,
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
config_dict = _download_mistral_config_file(model, revision)
|
||||
if (max_position_embeddings :=
|
||||
config_dict.get("max_position_embeddings")) is None:
|
||||
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
|
||||
model, revision, **kwargs)
|
||||
config_dict["max_position_embeddings"] = max_position_embeddings
|
||||
|
||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||
|
||||
config = adapt_config_dict(config_dict)
|
||||
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if ((sliding_window := getattr(config, "sliding_window", None))
|
||||
and isinstance(sliding_window, list)):
|
||||
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
config.layer_types = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
config.sliding_window = next(filter(None, sliding_window), None)
|
||||
|
||||
return config_dict, config
|
||||
|
||||
|
||||
_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = {
|
||||
"hf": HFConfigParser,
|
||||
"mistral": MistralConfigParser,
|
||||
}
|
||||
|
||||
ConfigFormat = Literal[
|
||||
"auto",
|
||||
"hf",
|
||||
"mistral",
|
||||
]
|
||||
|
||||
|
||||
def get_config_parser(config_format: str) -> ConfigParserBase:
|
||||
"""Get the config parser for a given config format."""
|
||||
if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER:
|
||||
raise ValueError(f"Unknown config format `{config_format}`.")
|
||||
return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]()
|
||||
|
||||
|
||||
def register_config_parser(config_format: str):
|
||||
|
||||
"""Register a customized vllm config parser.
|
||||
When a config format is not supported by vllm, you can register a customized
|
||||
config parser to support it.
|
||||
Args:
|
||||
config_format (str): The config parser format name.
|
||||
Examples:
|
||||
|
||||
>>> from vllm.transformers_utils.config import (get_config_parser,
|
||||
register_config_parser)
|
||||
>>> from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||
>>>
|
||||
>>> @register_config_parser("custom_config_parser")
|
||||
... class CustomConfigParser(ConfigParserBase):
|
||||
... def parse(self,
|
||||
... model: Union[str, Path],
|
||||
... trust_remote_code: bool,
|
||||
... revision: Optional[str] = None,
|
||||
... code_revision: Optional[str] = None,
|
||||
... **kwargs) -> tuple[dict, PretrainedConfig]:
|
||||
... raise NotImplementedError
|
||||
>>>
|
||||
>>> type(get_config_parser("custom_config_parser"))
|
||||
<class 'CustomConfigParser'>
|
||||
""" # noqa: E501
|
||||
|
||||
def _wrapper(config_parser_cls):
|
||||
if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
|
||||
logger.warning(
|
||||
"Config format `%s` is already registered, and will be "
|
||||
"overwritten by the new parser class `%s`.", config_format,
|
||||
config_parser_cls)
|
||||
if not issubclass(config_parser_cls, ConfigParserBase):
|
||||
raise ValueError("The config parser must be a subclass of "
|
||||
"`ConfigParserBase`.")
|
||||
_CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
|
||||
logger.info("Registered config parser `%s` with config format `%s`",
|
||||
config_parser_cls, config_format)
|
||||
return config_parser_cls
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
_R = TypeVar("_R")
|
||||
@ -350,7 +503,7 @@ def get_config(
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
config_format: Union[str, ConfigFormat] = "auto",
|
||||
hf_overrides_kw: Optional[dict[str, Any]] = None,
|
||||
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
|
||||
PretrainedConfig]] = None,
|
||||
@ -363,20 +516,22 @@ def get_config(
|
||||
kwargs["gguf_file"] = Path(model).name
|
||||
model = Path(model).parent
|
||||
|
||||
if config_format == ConfigFormat.AUTO:
|
||||
if config_format == "auto":
|
||||
try:
|
||||
if is_gguf or file_or_path_exists(
|
||||
model, HF_CONFIG_NAME, revision=revision):
|
||||
config_format = ConfigFormat.HF
|
||||
config_format = "hf"
|
||||
elif file_or_path_exists(model,
|
||||
MISTRAL_CONFIG_NAME,
|
||||
revision=revision):
|
||||
config_format = ConfigFormat.MISTRAL
|
||||
config_format = "mistral"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not detect config format for no config file found. "
|
||||
"Ensure your model has either config.json (HF format) "
|
||||
"or params.json (Mistral format).")
|
||||
"With config_format 'auto', ensure your model has either"
|
||||
"config.json (HF format) or params.json (Mistral format)."
|
||||
"Otherwise please specify your_custom_config_format"
|
||||
"in engine args for customized config parser")
|
||||
|
||||
except Exception as e:
|
||||
error_message = (
|
||||
@ -395,92 +550,14 @@ def get_config(
|
||||
|
||||
raise ValueError(error_message) from e
|
||||
|
||||
if config_format == ConfigFormat.HF:
|
||||
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
# Use custom model class if it's in our registry
|
||||
model_type = config_dict.get("model_type")
|
||||
if model_type is None:
|
||||
model_type = "speculators" if config_dict.get(
|
||||
"speculators_config") is not None else model_type
|
||||
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(
|
||||
model,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
kwargs = _maybe_update_auto_config_kwargs(
|
||||
kwargs, model_type=model_type)
|
||||
config = AutoConfig.from_pretrained(
|
||||
model,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code
|
||||
and "requires you to execute the configuration file"
|
||||
in str(e)):
|
||||
err_msg = (
|
||||
"Failed to load the model config. If the model "
|
||||
"is a custom model not yet available in the "
|
||||
"HuggingFace transformers library, consider setting "
|
||||
"`trust_remote_code=True` in LLM or using the "
|
||||
"`--trust-remote-code` flag in the CLI.")
|
||||
raise RuntimeError(err_msg) from e
|
||||
else:
|
||||
raise e
|
||||
config = _maybe_remap_hf_config_attrs(config)
|
||||
|
||||
elif config_format == ConfigFormat.MISTRAL:
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
config_dict = _download_mistral_config_file(model, revision)
|
||||
if (max_position_embeddings :=
|
||||
config_dict.get("max_position_embeddings")) is None:
|
||||
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
|
||||
model, revision, **kwargs)
|
||||
config_dict["max_position_embeddings"] = max_position_embeddings
|
||||
|
||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||
|
||||
config = adapt_config_dict(config_dict)
|
||||
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if ((sliding_window := getattr(config, "sliding_window", None))
|
||||
and isinstance(sliding_window, list)):
|
||||
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
config.layer_types = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
config.sliding_window = next(filter(None, sliding_window), None)
|
||||
else:
|
||||
supported_formats = [
|
||||
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
|
||||
]
|
||||
raise ValueError(
|
||||
f"Unsupported config format: {config_format}. "
|
||||
f"Supported formats are: {', '.join(supported_formats)}. "
|
||||
f"Ensure your model uses one of these configuration formats "
|
||||
f"or specify the correct format explicitly.")
|
||||
|
||||
config_parser = get_config_parser(config_format)
|
||||
config_dict, config = config_parser.parse(
|
||||
model,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
**kwargs,
|
||||
)
|
||||
# Special architecture mapping check for GGUF models
|
||||
if is_gguf:
|
||||
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
@ -914,7 +991,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
|
||||
hf_config = get_config(model=model,
|
||||
trust_remote_code=trust_remote_code_val,
|
||||
revision=revision,
|
||||
config_format=ConfigFormat.HF)
|
||||
config_format="hf")
|
||||
if hf_value := hf_config.get_text_config().max_position_embeddings:
|
||||
max_position_embeddings = hf_value
|
||||
except Exception as e:
|
||||
|
||||
20
vllm/transformers_utils/config_parser_base.py
Normal file
20
vllm/transformers_utils/config_parser_base.py
Normal file
@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class ConfigParserBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def parse(self,
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||
raise NotImplementedError
|
||||
Loading…
x
Reference in New Issue
Block a user