Improve conversion from dataclass configs to argparse arguments (#17303)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-28 17:22:12 +01:00 committed by GitHub
parent 72dfe4c74f
commit f94886946e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 247 additions and 156 deletions

View File

@ -1,14 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
import json
from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional
import pytest
from vllm.config import PoolerConfig
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.config import PoolerConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
nullable_kvs, optional_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
}),
(json.loads, "foo=1,bar=2", {
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
(int, int, True),
(int, float, False),
(list[int], list, True),
(list[int], tuple, False),
(Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, type, expected):
assert is_type(type_hint, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
({float, int}, int, True),
({int, tuple[int]}, int, True),
({int, tuple[int]}, float, False),
({str, Literal["x", "y"]}, Literal, True),
])
def test_contains_type(type_hints, type, expected):
assert contains_type(type_hints, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
({int, float}, int, int),
({int, float}, str, None),
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
])
def test_get_type(type_hints, type, expected):
assert get_type(type_hints, type) == expected
@config
@dataclass
class DummyConfigClass:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
"""Optional bool with default None"""
optional_literal: Optional[Literal["x", "y"]] = None
"""Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
"""Tuple with default (1, 2, 3)"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
"""Tuple with default (1, 2)"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
"""List with default [1, 2, 3]"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
print(kwargs)
# bools should not have their type set
assert kwargs["regular_bool"].get("type") is None
assert kwargs["optional_bool"].get("type") is None
# optional literals should have None as a choice
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
# tuples should have the correct nargs
assert kwargs["tuple_n"]["nargs"] == "+"
assert kwargs["tuple_2"]["nargs"] == 2
# lists should work
assert kwargs["list_n"]["type"] is int
assert kwargs["list_n"]["nargs"] == "+"
@pytest.mark.parametrize(("arg", "expected"), [
(None, dict()),
("image=16", {

View File

@ -11,7 +11,7 @@ from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm import version
@ -48,33 +48,29 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
if val == "" or val == "None":
return None
try:
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _optional_type
def optional_str(val: str) -> Optional[str]:
return optional_arg(val, str)
def optional_int(val: str) -> Optional[int]:
return optional_arg(val, int)
def optional_float(val: str) -> Optional[float]:
return optional_arg(val, float)
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
"""NOTE: This function is deprecated, args should be passed as JSON
strings instead.
Parses a string containing comma separate key [str] to value [int]
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
@ -83,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
Returns:
Dictionary with parsed values.
"""
if len(val) == 0:
return None
out_dict: Dict[str, int] = {}
out_dict: dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
@ -108,15 +101,103 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
return out_dict
def optional_dict(val: str) -> Optional[dict[str, int]]:
if re.match("^{.*}$", val):
return optional_arg(val, json.loads)
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the type hint is a specific type."""
return type_hint is type or get_origin(type_hint) is type
logger.warning(
"Failed to parse JSON string. Attempting to parse as "
"comma-separated key=value pairs. This will be deprecated in a "
"future release.")
return nullable_kvs(val)
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
"""Check if the type hints contain a specific type."""
return any(is_type(type_hint, type) for type_hint in type_hints)
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
"""Get the specific type from the type hints."""
return next((th for th in type_hints if is_type(th, type)), None)
def is_not_builtin(type_hint: TypeHint) -> bool:
"""Check if the class is not a built-in type."""
return type_hint.__module__ != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
if contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
# Creates choices from Literal arguments
type_hint = get_type(type_hints, Literal)
choices = sorted(get_args(type_hint))
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}")
kwargs[name]["type"] = choice_type
elif contains_type(type_hints, tuple):
type_hint = get_type(type_hints, tuple)
types = get_args(type_hint)
tuple_type = types[0]
assert all(t is tuple_type for t in types if t is not Ellipsis), (
"All non-Ellipsis tuple elements must be of the same "
f"type. Got {types}.")
kwargs[name]["type"] = tuple_type
kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
elif contains_type(type_hints, list):
type_hint = get_type(type_hints, list)
types = get_args(type_hint)
assert len(types) == 1, (
"List type must have exactly one type. Got "
f"{type_hint} with types {types}")
kwargs[name]["type"] = types[0]
kwargs[name]["nargs"] = "+"
elif contains_type(type_hints, int):
kwargs[name]["type"] = int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str
else:
raise ValueError(
f"Unsupported type {type_hints} for argument {name}.")
# If None is in type_hints, make the argument optional.
# But not if it's a bool, argparse will handle this better.
if type(None) in type_hints and not contains_type(type_hints, bool):
kwargs[name]["type"] = optional_type(kwargs[name]["type"])
if kwargs[name].get("choices"):
kwargs[name]["choices"].append("None")
return kwargs
@dataclass
@ -279,100 +360,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
"""Check if the class is a type in a union type."""
is_union = get_origin(cls) is Union
type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
return is_union and type_in_union
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
"""Get the type in a union type."""
for arg in get_args(cls):
if (get_origin(arg) or arg) is type:
return arg
raise ValueError(f"Type {type} not found in union type {cls}.")
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
"""Check if the class is an optional type."""
return is_type_in_union(cls, type(None))
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the class can be of type."""
return cls is type or get_origin(cls) is type or is_type_in_union(
cls, type)
def is_custom_type(cls: TypeHint) -> bool:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Make note of if the field is optional and get the actual
# type of the field if it is
optional = is_optional(field.type)
field_type = get_args(
field.type)[0] if optional else field.type
# Set type, action and choices for the field depending on the
# type of the field
if can_be_type(field_type, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
kwargs[name]["type"] = bool
elif can_be_type(field_type, Literal):
# Creates choices from Literal arguments
if is_type_in_union(field_type, Literal):
field_type = get_type_from_union(field_type, Literal)
choices = get_args(field_type)
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}"
)
kwargs[name]["type"] = choice_type
elif can_be_type(field_type, tuple):
if is_type_in_union(field_type, tuple):
field_type = get_type_from_union(field_type, tuple)
dtypes = get_args(field_type)
dtype = dtypes[0]
assert all(
d is dtype for d in dtypes if d is not Ellipsis
), ("All non-Ellipsis tuple elements must be of the same "
f"type. Got {dtypes}.")
kwargs[name]["type"] = dtype
kwargs[name]["nargs"] = "+"
elif can_be_type(field_type, int):
kwargs[name]["type"] = optional_int if optional else int
elif can_be_type(field_type, float):
kwargs[name][
"type"] = optional_float if optional else float
elif can_be_type(field_type, dict):
kwargs[name]["type"] = optional_dict
elif (can_be_type(field_type, str)
or is_custom_type(field_type)):
kwargs[name]["type"] = optional_str if optional else str
else:
raise ValueError(
f"Unsupported type {field.type} for argument {name}. ")
return kwargs
# Model arguments
parser.add_argument(
'--model',
@ -390,13 +377,13 @@ class EngineArgs:
'which task to use.')
parser.add_argument(
'--tokenizer',
type=optional_str,
type=optional_type(str),
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=optional_str,
type=optional_type(str),
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
@ -408,21 +395,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=optional_str,
type=optional_type(str),
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=optional_str,
type=optional_type(str),
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=optional_str,
type=optional_type(str),
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
@ -513,7 +500,7 @@ class EngineArgs:
parser.add_argument(
'--logits-processor-pattern',
type=optional_str,
type=optional_type(str),
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
@ -612,7 +599,7 @@ class EngineArgs:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=optional_str,
type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
@ -921,7 +908,7 @@ class EngineArgs:
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=optional_str,
type=optional_type(str),
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "

View File

@ -11,7 +11,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=optional_str,
type=optional_type(str),
default=None,
help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.")
@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=["*"],
help="Allowed headers.")
parser.add_argument("--api-key",
type=optional_str,
type=optional_type(str),
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=optional_str,
type=optional_type(str),
default=None,
nargs='+',
action=LoRAParserAction,
@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=optional_str,
type=optional_type(str),
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role",
type=optional_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs",
type=optional_str,
type=optional_type(str),
default=None,
help="The CA certificates file.")
parser.add_argument(
@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
parser.add_argument(
"--root-path",
type=optional_str,
type=optional_type(str),
default=None,
help="FastAPI root_path when app is behind a path based routing proxy."
)
parser.add_argument(
"--middleware",
type=optional_str,
type=optional_type(str),
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "

View File

@ -12,7 +12,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
@ -61,7 +61,7 @@ def parse_args():
"to the output URL.",
)
parser.add_argument("--response-role",
type=optional_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")