mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:06:15 +08:00
[Core] Support configuration parsing plugin (#24277)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com> Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
4032949630
commit
9fb74c27a7
0
tests/transformers_utils/__init__.py
Normal file
0
tests/transformers_utils/__init__.py
Normal file
37
tests/transformers_utils/test_config_parser_registry.py
Normal file
37
tests/transformers_utils/test_config_parser_registry.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.transformers_utils.config import (get_config_parser,
|
||||||
|
register_config_parser)
|
||||||
|
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||||
|
|
||||||
|
|
||||||
|
@register_config_parser("custom_config_parser")
|
||||||
|
class CustomConfigParser(ConfigParserBase):
|
||||||
|
|
||||||
|
def parse(self,
|
||||||
|
model: Union[str, Path],
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
|
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_config_parser():
|
||||||
|
assert isinstance(get_config_parser("custom_config_parser"),
|
||||||
|
CustomConfigParser)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_config_parser():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
@register_config_parser("invalid_config_parser")
|
||||||
|
class InvalidConfigParser:
|
||||||
|
pass
|
||||||
@ -419,7 +419,7 @@ class ModelConfig:
|
|||||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
|
||||||
use_async_output_proc: bool = True
|
use_async_output_proc: bool = True
|
||||||
"""Whether to use async output processor."""
|
"""Whether to use async output processor."""
|
||||||
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
|
config_format: Union[str, ConfigFormat] = "auto"
|
||||||
"""The format of the model config to load:\n
|
"""The format of the model config to load:\n
|
||||||
- "auto" will try to load the config in hf format if available else it
|
- "auto" will try to load the config in hf format if available else it
|
||||||
will try to load in mistral format.\n
|
will try to load in mistral format.\n
|
||||||
@ -624,9 +624,6 @@ class ModelConfig:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sleep mode is not supported on current platform.")
|
"Sleep mode is not supported on current platform.")
|
||||||
|
|
||||||
if isinstance(self.config_format, str):
|
|
||||||
self.config_format = ConfigFormat(self.config_format)
|
|
||||||
|
|
||||||
hf_config = get_config(self.hf_config_path or self.model,
|
hf_config = get_config(self.hf_config_path or self.model,
|
||||||
self.trust_remote_code,
|
self.trust_remote_code,
|
||||||
self.revision,
|
self.revision,
|
||||||
|
|||||||
@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||||
ConfigFormat, ConfigType, ConvertOption,
|
ConfigType, ConvertOption, DecodingConfig,
|
||||||
DecodingConfig, DetailedTraceModules, Device,
|
DetailedTraceModules, Device, DeviceConfig,
|
||||||
DeviceConfig, DistributedExecutorBackend, EPLBConfig,
|
DistributedExecutorBackend, EPLBConfig,
|
||||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||||
@ -547,7 +547,6 @@ class EngineArgs:
|
|||||||
help="Disable async output processing. This may result in "
|
help="Disable async output processing. This may result in "
|
||||||
"lower performance.")
|
"lower performance.")
|
||||||
model_group.add_argument("--config-format",
|
model_group.add_argument("--config-format",
|
||||||
choices=[f.value for f in ConfigFormat],
|
|
||||||
**model_kwargs["config_format"])
|
**model_kwargs["config_format"])
|
||||||
# This one is a special case because it can bool
|
# This one is a special case because it can bool
|
||||||
# or str. TODO: Handle this in get_kwargs
|
# or str. TODO: Handle this in get_kwargs
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import enum
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from functools import cache, partial
|
from functools import cache, partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Optional, TypeVar, Union
|
from typing import Any, Callable, Literal, Optional, TypeVar, Union
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from huggingface_hub import get_safetensors_metadata, hf_hub_download
|
from huggingface_hub import get_safetensors_metadata, hf_hub_download
|
||||||
@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
@ -100,10 +100,163 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConfigFormat(str, enum.Enum):
|
class HFConfigParser(ConfigParserBase):
|
||||||
AUTO = "auto"
|
|
||||||
HF = "hf"
|
def parse(self,
|
||||||
MISTRAL = "mistral"
|
model: Union[str, Path],
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
|
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||||
|
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||||
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
|
model,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision,
|
||||||
|
token=_get_hf_token(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Use custom model class if it's in our registry
|
||||||
|
model_type = config_dict.get("model_type")
|
||||||
|
if model_type is None:
|
||||||
|
model_type = "speculators" if config_dict.get(
|
||||||
|
"speculators_config") is not None else model_type
|
||||||
|
|
||||||
|
if model_type in _CONFIG_REGISTRY:
|
||||||
|
config_class = _CONFIG_REGISTRY[model_type]
|
||||||
|
config = config_class.from_pretrained(
|
||||||
|
model,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision,
|
||||||
|
token=_get_hf_token(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
kwargs = _maybe_update_auto_config_kwargs(
|
||||||
|
kwargs, model_type=model_type)
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision,
|
||||||
|
token=_get_hf_token(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
if (not trust_remote_code
|
||||||
|
and "requires you to execute the configuration file"
|
||||||
|
in str(e)):
|
||||||
|
err_msg = (
|
||||||
|
"Failed to load the model config. If the model "
|
||||||
|
"is a custom model not yet available in the "
|
||||||
|
"HuggingFace transformers library, consider setting "
|
||||||
|
"`trust_remote_code=True` in LLM or using the "
|
||||||
|
"`--trust-remote-code` flag in the CLI.")
|
||||||
|
raise RuntimeError(err_msg) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
config = _maybe_remap_hf_config_attrs(config)
|
||||||
|
return config_dict, config
|
||||||
|
|
||||||
|
|
||||||
|
class MistralConfigParser(ConfigParserBase):
|
||||||
|
|
||||||
|
def parse(self,
|
||||||
|
model: Union[str, Path],
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
|
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||||
|
# This function loads a params.json config which
|
||||||
|
# should be used when loading models in mistral format
|
||||||
|
config_dict = _download_mistral_config_file(model, revision)
|
||||||
|
if (max_position_embeddings :=
|
||||||
|
config_dict.get("max_position_embeddings")) is None:
|
||||||
|
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
|
||||||
|
model, revision, **kwargs)
|
||||||
|
config_dict["max_position_embeddings"] = max_position_embeddings
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||||
|
|
||||||
|
config = adapt_config_dict(config_dict)
|
||||||
|
|
||||||
|
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||||
|
# to int and add the layer_types list[str] to make it HF compatible
|
||||||
|
if ((sliding_window := getattr(config, "sliding_window", None))
|
||||||
|
and isinstance(sliding_window, list)):
|
||||||
|
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
||||||
|
layer_types = sliding_window * pattern_repeats
|
||||||
|
config.layer_types = [
|
||||||
|
"full_attention" if layer_type is None else "sliding_attention"
|
||||||
|
for layer_type in layer_types
|
||||||
|
]
|
||||||
|
config.sliding_window = next(filter(None, sliding_window), None)
|
||||||
|
|
||||||
|
return config_dict, config
|
||||||
|
|
||||||
|
|
||||||
|
_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = {
|
||||||
|
"hf": HFConfigParser,
|
||||||
|
"mistral": MistralConfigParser,
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfigFormat = Literal[
|
||||||
|
"auto",
|
||||||
|
"hf",
|
||||||
|
"mistral",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_parser(config_format: str) -> ConfigParserBase:
|
||||||
|
"""Get the config parser for a given config format."""
|
||||||
|
if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER:
|
||||||
|
raise ValueError(f"Unknown config format `{config_format}`.")
|
||||||
|
return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]()
|
||||||
|
|
||||||
|
|
||||||
|
def register_config_parser(config_format: str):
|
||||||
|
|
||||||
|
"""Register a customized vllm config parser.
|
||||||
|
When a config format is not supported by vllm, you can register a customized
|
||||||
|
config parser to support it.
|
||||||
|
Args:
|
||||||
|
config_format (str): The config parser format name.
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
>>> from vllm.transformers_utils.config import (get_config_parser,
|
||||||
|
register_config_parser)
|
||||||
|
>>> from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||||
|
>>>
|
||||||
|
>>> @register_config_parser("custom_config_parser")
|
||||||
|
... class CustomConfigParser(ConfigParserBase):
|
||||||
|
... def parse(self,
|
||||||
|
... model: Union[str, Path],
|
||||||
|
... trust_remote_code: bool,
|
||||||
|
... revision: Optional[str] = None,
|
||||||
|
... code_revision: Optional[str] = None,
|
||||||
|
... **kwargs) -> tuple[dict, PretrainedConfig]:
|
||||||
|
... raise NotImplementedError
|
||||||
|
>>>
|
||||||
|
>>> type(get_config_parser("custom_config_parser"))
|
||||||
|
<class 'CustomConfigParser'>
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def _wrapper(config_parser_cls):
|
||||||
|
if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
|
||||||
|
logger.warning(
|
||||||
|
"Config format `%s` is already registered, and will be "
|
||||||
|
"overwritten by the new parser class `%s`.", config_format,
|
||||||
|
config_parser_cls)
|
||||||
|
if not issubclass(config_parser_cls, ConfigParserBase):
|
||||||
|
raise ValueError("The config parser must be a subclass of "
|
||||||
|
"`ConfigParserBase`.")
|
||||||
|
_CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
|
||||||
|
logger.info("Registered config parser `%s` with config format `%s`",
|
||||||
|
config_parser_cls, config_format)
|
||||||
|
return config_parser_cls
|
||||||
|
|
||||||
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
_R = TypeVar("_R")
|
_R = TypeVar("_R")
|
||||||
@ -350,7 +503,7 @@ def get_config(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
code_revision: Optional[str] = None,
|
code_revision: Optional[str] = None,
|
||||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
config_format: Union[str, ConfigFormat] = "auto",
|
||||||
hf_overrides_kw: Optional[dict[str, Any]] = None,
|
hf_overrides_kw: Optional[dict[str, Any]] = None,
|
||||||
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
|
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
|
||||||
PretrainedConfig]] = None,
|
PretrainedConfig]] = None,
|
||||||
@ -363,20 +516,22 @@ def get_config(
|
|||||||
kwargs["gguf_file"] = Path(model).name
|
kwargs["gguf_file"] = Path(model).name
|
||||||
model = Path(model).parent
|
model = Path(model).parent
|
||||||
|
|
||||||
if config_format == ConfigFormat.AUTO:
|
if config_format == "auto":
|
||||||
try:
|
try:
|
||||||
if is_gguf or file_or_path_exists(
|
if is_gguf or file_or_path_exists(
|
||||||
model, HF_CONFIG_NAME, revision=revision):
|
model, HF_CONFIG_NAME, revision=revision):
|
||||||
config_format = ConfigFormat.HF
|
config_format = "hf"
|
||||||
elif file_or_path_exists(model,
|
elif file_or_path_exists(model,
|
||||||
MISTRAL_CONFIG_NAME,
|
MISTRAL_CONFIG_NAME,
|
||||||
revision=revision):
|
revision=revision):
|
||||||
config_format = ConfigFormat.MISTRAL
|
config_format = "mistral"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not detect config format for no config file found. "
|
"Could not detect config format for no config file found. "
|
||||||
"Ensure your model has either config.json (HF format) "
|
"With config_format 'auto', ensure your model has either"
|
||||||
"or params.json (Mistral format).")
|
"config.json (HF format) or params.json (Mistral format)."
|
||||||
|
"Otherwise please specify your_custom_config_format"
|
||||||
|
"in engine args for customized config parser")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = (
|
error_message = (
|
||||||
@ -395,92 +550,14 @@ def get_config(
|
|||||||
|
|
||||||
raise ValueError(error_message) from e
|
raise ValueError(error_message) from e
|
||||||
|
|
||||||
if config_format == ConfigFormat.HF:
|
config_parser = get_config_parser(config_format)
|
||||||
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
config_dict, config = config_parser.parse(
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
|
||||||
model,
|
|
||||||
revision=revision,
|
|
||||||
code_revision=code_revision,
|
|
||||||
token=_get_hf_token(),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
# Use custom model class if it's in our registry
|
|
||||||
model_type = config_dict.get("model_type")
|
|
||||||
if model_type is None:
|
|
||||||
model_type = "speculators" if config_dict.get(
|
|
||||||
"speculators_config") is not None else model_type
|
|
||||||
|
|
||||||
if model_type in _CONFIG_REGISTRY:
|
|
||||||
config_class = _CONFIG_REGISTRY[model_type]
|
|
||||||
config = config_class.from_pretrained(
|
|
||||||
model,
|
|
||||||
revision=revision,
|
|
||||||
code_revision=code_revision,
|
|
||||||
token=_get_hf_token(),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
kwargs = _maybe_update_auto_config_kwargs(
|
|
||||||
kwargs, model_type=model_type)
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model,
|
model,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
code_revision=code_revision,
|
code_revision=code_revision,
|
||||||
token=_get_hf_token(),
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
|
||||||
if (not trust_remote_code
|
|
||||||
and "requires you to execute the configuration file"
|
|
||||||
in str(e)):
|
|
||||||
err_msg = (
|
|
||||||
"Failed to load the model config. If the model "
|
|
||||||
"is a custom model not yet available in the "
|
|
||||||
"HuggingFace transformers library, consider setting "
|
|
||||||
"`trust_remote_code=True` in LLM or using the "
|
|
||||||
"`--trust-remote-code` flag in the CLI.")
|
|
||||||
raise RuntimeError(err_msg) from e
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
config = _maybe_remap_hf_config_attrs(config)
|
|
||||||
|
|
||||||
elif config_format == ConfigFormat.MISTRAL:
|
|
||||||
# This function loads a params.json config which
|
|
||||||
# should be used when loading models in mistral format
|
|
||||||
config_dict = _download_mistral_config_file(model, revision)
|
|
||||||
if (max_position_embeddings :=
|
|
||||||
config_dict.get("max_position_embeddings")) is None:
|
|
||||||
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
|
|
||||||
model, revision, **kwargs)
|
|
||||||
config_dict["max_position_embeddings"] = max_position_embeddings
|
|
||||||
|
|
||||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
|
||||||
|
|
||||||
config = adapt_config_dict(config_dict)
|
|
||||||
|
|
||||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
|
||||||
# to int and add the layer_types list[str] to make it HF compatible
|
|
||||||
if ((sliding_window := getattr(config, "sliding_window", None))
|
|
||||||
and isinstance(sliding_window, list)):
|
|
||||||
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
|
||||||
layer_types = sliding_window * pattern_repeats
|
|
||||||
config.layer_types = [
|
|
||||||
"full_attention" if layer_type is None else "sliding_attention"
|
|
||||||
for layer_type in layer_types
|
|
||||||
]
|
|
||||||
config.sliding_window = next(filter(None, sliding_window), None)
|
|
||||||
else:
|
|
||||||
supported_formats = [
|
|
||||||
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
|
|
||||||
]
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported config format: {config_format}. "
|
|
||||||
f"Supported formats are: {', '.join(supported_formats)}. "
|
|
||||||
f"Ensure your model uses one of these configuration formats "
|
|
||||||
f"or specify the correct format explicitly.")
|
|
||||||
|
|
||||||
# Special architecture mapping check for GGUF models
|
# Special architecture mapping check for GGUF models
|
||||||
if is_gguf:
|
if is_gguf:
|
||||||
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
@ -914,7 +991,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
|
|||||||
hf_config = get_config(model=model,
|
hf_config = get_config(model=model,
|
||||||
trust_remote_code=trust_remote_code_val,
|
trust_remote_code=trust_remote_code_val,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
config_format=ConfigFormat.HF)
|
config_format="hf")
|
||||||
if hf_value := hf_config.get_text_config().max_position_embeddings:
|
if hf_value := hf_config.get_text_config().max_position_embeddings:
|
||||||
max_position_embeddings = hf_value
|
max_position_embeddings = hf_value
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
20
vllm/transformers_utils/config_parser_base.py
Normal file
20
vllm/transformers_utils/config_parser_base.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigParserBase(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse(self,
|
||||||
|
model: Union[str, Path],
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
|
**kwargs) -> tuple[dict, PretrainedConfig]:
|
||||||
|
raise NotImplementedError
|
||||||
Loading…
x
Reference in New Issue
Block a user