Improve conversion from dataclass configs to argparse arguments (#17303)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-28 17:22:12 +01:00 committed by GitHub
parent 72dfe4c74f
commit f94886946e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 247 additions and 156 deletions

View File

@ -1,14 +1,118 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
from argparse import ArgumentError, ArgumentTypeError from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional
import pytest import pytest
from vllm.config import PoolerConfig from vllm.config import PoolerConfig, config
from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
nullable_kvs, optional_type)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
    (int, "42", 42),
    (int, "None", None),
    (float, "3.14", 3.14),
    (float, "None", None),
    (str, "Hello World!", "Hello World!"),
    (str, "None", None),
    (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
    (json.loads, "foo=1,bar=2", {"foo": 1, "bar": 2}),
    (json.loads, "None", None),
])
def test_optional_type(type, value, expected):
    """Check the converter returned by `optional_type` for each base type.

    The string "None" must convert to `None`; any other value must be
    converted by the wrapped callable.
    """
    # Only the legacy comma-separated key=value input is expected to emit a
    # DeprecationWarning; all other cases run under a no-op context.
    is_legacy_kv = value == "foo=1,bar=2"
    expectation = (pytest.warns(DeprecationWarning)
                   if is_legacy_kv else nullcontext())
    convert = optional_type(type)
    with expectation:
        assert convert(value) == expected
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
    (int, int, True),
    (int, float, False),
    (list[int], list, True),
    (list[int], tuple, False),
    (Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, type, expected):
    """`is_type` should match a hint that is the type or is parameterised from it."""
    matched = is_type(type_hint, type)
    assert matched == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
    ({float, int}, int, True),
    ({int, tuple[int]}, int, True),
    ({int, tuple[int]}, float, False),
    ({str, Literal["x", "y"]}, Literal, True),
])
def test_contains_type(type_hints, type, expected):
    """`contains_type` should report whether any hint in the set matches."""
    found = contains_type(type_hints, type)
    assert found == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
    ({int, float}, int, int),
    ({int, float}, str, None),
    ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
])
def test_get_type(type_hints, type, expected):
    """`get_type` should return the matching hint, or None when absent."""
    selected = get_type(type_hints, type)
    assert selected == expected
# Minimal config dataclass used as a fixture by the `is_not_builtin` and
# `get_kwargs` tests below. It covers one field per shape that the tests
# assert on: plain bool, optional bool, optional Literal, variable-length
# tuple, fixed-length tuple and list.
# NOTE(review): the bare string after each field is presumably what
# `get_attr_docs` picks up as that field's help text -- keep each string
# directly below its field.
@config
@dataclass
class DummyConfigClass:
    regular_bool: bool = True
    """Regular bool with default True"""
    optional_bool: Optional[bool] = None
    """Optional bool with default None"""
    optional_literal: Optional[Literal["x", "y"]] = None
    """Optional literal with default None"""
    tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
    """Tuple with default (1, 2, 3)"""
    tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
    """Tuple with default (1, 2)"""
    list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
    """List with default [1, 2, 3]"""
@pytest.mark.parametrize(("type_hint", "expected"), [
    (int, False),
    (DummyConfigClass, True),
])
def test_is_not_builtin(type_hint, expected):
    """Built-in types are rejected; a user-defined class is accepted."""
    outcome = is_not_builtin(type_hint)
    assert outcome == expected
def test_get_kwargs():
    """Check the argparse kwargs generated for each `DummyConfigClass` field."""
    kwargs = get_kwargs(DummyConfigClass)

    # bools should not have their type set
    assert kwargs["regular_bool"].get("type") is None
    assert kwargs["optional_bool"].get("type") is None
    # optional literals should have None as a choice
    assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
    # tuples should have the correct nargs
    assert kwargs["tuple_n"]["nargs"] == "+"
    assert kwargs["tuple_2"]["nargs"] == 2
    # lists should work
    assert kwargs["list_n"]["type"] is int
    assert kwargs["list_n"]["nargs"] == "+"
@pytest.mark.parametrize(("arg", "expected"), [ @pytest.mark.parametrize(("arg", "expected"), [
(None, dict()), (None, dict()),
("image=16", { ("image=16", {

View File

@ -11,7 +11,7 @@ from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin) TypeVar, Union, cast, get_args, get_origin)
import torch import torch
from typing_extensions import TypeIs from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm import version from vllm import version
@ -48,33 +48,29 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object] TypeHintT = Union[type[T], object]
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]: def optional_type(
if val == "" or val == "None": return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
return None
try: def _optional_type(val: str) -> Optional[T]:
return return_type(val) if val == "" or val == "None":
except ValueError as e: return None
raise argparse.ArgumentTypeError( try:
f"Value {val} cannot be converted to {return_type}.") from e if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _optional_type
def optional_str(val: str) -> Optional[str]: @deprecated(
return optional_arg(val, str) "Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def optional_int(val: str) -> Optional[int]: def nullable_kvs(val: str) -> dict[str, int]:
return optional_arg(val, int) """Parses a string containing comma separate key [str] to value [int]
def optional_float(val: str) -> Optional[float]:
return optional_arg(val, float)
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
"""NOTE: This function is deprecated, args should be passed as JSON
strings instead.
Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary. pairs into a dictionary.
Args: Args:
@ -83,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
Returns: Returns:
Dictionary with parsed values. Dictionary with parsed values.
""" """
if len(val) == 0: out_dict: dict[str, int] = {}
return None
out_dict: Dict[str, int] = {}
for item in val.split(","): for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")] kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2: if len(kv_parts) != 2:
@ -108,15 +101,103 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
return out_dict return out_dict
def optional_dict(val: str) -> Optional[dict[str, int]]: def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
if re.match("^{.*}$", val): """Check if the type hint is a specific type."""
return optional_arg(val, json.loads) return type_hint is type or get_origin(type_hint) is type
logger.warning(
"Failed to parse JSON string. Attempting to parse as " def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
"comma-separated key=value pairs. This will be deprecated in a " """Check if the type hints contain a specific type."""
"future release.") return any(is_type(type_hint, type) for type_hint in type_hints)
return nullable_kvs(val)
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` that matches `type`.

    Matching is delegated to `is_type`. Returns None when no hint matches.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Report whether `type_hint` is defined outside the `builtins` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build `add_argument` keyword arguments for every field of a config.

    Args:
        cls: Config dataclass whose fields are turned into CLI arguments.

    Returns:
        A dict keyed by field name. Each value holds "default" and "help",
        plus whichever of "action", "choices", "type" and "nargs" the field's
        type hints call for.

    Raises:
        ValueError: If a field's type hints match none of the handled types.
    """
    docs = get_attr_docs(cls)
    result = {}
    for dc_field in fields(cls):
        name = dc_field.name

        # A default factory, when present, supplies the default value.
        if dc_field.default_factory is MISSING:
            default = dc_field.default
        else:
            default = dc_field.default_factory()

        # Escape % for argparse
        arg = {"default": default, "help": docs[name].replace("%", "%%")}
        result[name] = arg

        # Flatten a Union into its members; otherwise use the hint as-is.
        annotation = dc_field.type
        if get_origin(annotation) is Union:
            type_hints: set[TypeHint] = set(get_args(annotation))
        else:
            type_hints = {annotation}

        # The order of these checks matters: the first match wins.
        if contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            arg["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            # Creates choices from Literal arguments
            literal_hint = get_type(type_hints, Literal)
            choices = sorted(get_args(literal_hint))
            arg["choices"] = choices
            choice_type = type(choices[0])
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
            arg["type"] = choice_type
        elif contains_type(type_hints, tuple):
            tuple_hint = get_type(type_hints, tuple)
            element_types = get_args(tuple_hint)
            element_type = element_types[0]
            assert all(t is element_type for t in element_types
                       if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {element_types}.")
            arg["type"] = element_type
            # Variable-length tuples take one or more values; fixed-length
            # tuples take exactly as many values as they have elements.
            if Ellipsis in element_types:
                arg["nargs"] = "+"
            else:
                arg["nargs"] = len(element_types)
        elif contains_type(type_hints, list):
            list_hint = get_type(type_hints, list)
            element_types = get_args(list_hint)
            assert len(element_types) == 1, (
                "List type must have exactly one type. Got "
                f"{list_hint} with types {element_types}")
            arg["type"] = element_types[0]
            arg["nargs"] = "+"
        elif contains_type(type_hints, int):
            arg["type"] = int
        elif contains_type(type_hints, float):
            arg["type"] = float
        elif contains_type(type_hints, dict):
            # Dict arguments will always be optional
            arg["type"] = optional_type(json.loads)
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            arg["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        accepts_none = type(None) in type_hints
        if accepts_none and not contains_type(type_hints, bool):
            arg["type"] = optional_type(arg["type"])
            # NOTE(review): the wrapped converter turns "None" into None,
            # while the choice added here is the string "None" -- confirm
            # argparse accepts the converted value against these choices.
            if arg.get("choices"):
                arg["choices"].append("None")
    return result
@dataclass @dataclass
@ -279,100 +360,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
    """Report whether `cls` is a Union that has `type` among its members.

    Parameterised members are compared by their origin, so `list[int]`
    counts as `list`.
    """
    if get_origin(cls) is not Union:
        return False
    members = [get_origin(member) or member for member in get_args(cls)]
    return type in members
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
    """Return the member of union `cls` whose origin (or self) is `type`.

    Raises:
        ValueError: If no member of `cls` matches `type`.
    """
    # A private sentinel distinguishes "no match" from any real member.
    absent = object()
    found = next(
        (member for member in get_args(cls)
         if (get_origin(member) or member) is type),
        absent,
    )
    if found is absent:
        raise ValueError(f"Type {type} not found in union type {cls}.")
    return found
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
    """Report whether `cls` is a Union that includes `NoneType`."""
    none_type = type(None)
    return is_type_in_union(cls, none_type)
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Report whether `cls` is `type`, is parameterised from it, or is a
    Union containing it."""
    if cls is type:
        return True
    if get_origin(cls) is type:
        return True
    return is_type_in_union(cls, type)
def is_custom_type(cls: TypeHint) -> bool:
    """Report whether `cls` is defined outside the `builtins` module."""
    defining_module = cls.__module__
    return defining_module != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build `add_argument` keyword arguments for every field of `cls`.

    Returns a dict keyed by field name. Each value holds "default" and
    "help", plus "type" and, depending on the field's type, "action",
    "choices" or "nargs".

    Raises:
        ValueError: If the field's type matches none of the handled types.
    """
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
        # Get the default value of the field
        default = field.default
        if field.default_factory is not MISSING:
            default = field.default_factory()
        # Get the help text for the field
        name = field.name
        help = cls_docs[name]
        # Escape % for argparse
        help = help.replace("%", "%%")
        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}
        # Make note of if the field is optional and get the actual
        # type of the field if it is
        optional = is_optional(field.type)
        # NOTE(review): takes the first union member, so this assumes the
        # non-None member is listed first -- confirm for multi-member unions.
        field_type = get_args(
            field.type)[0] if optional else field.type
        # Set type, action and choices for the field depending on the
        # type of the field
        # The order of these checks matters: the first match wins.
        if can_be_type(field_type, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
            kwargs[name]["type"] = bool
        elif can_be_type(field_type, Literal):
            # Creates choices from Literal arguments
            if is_type_in_union(field_type, Literal):
                field_type = get_type_from_union(field_type, Literal)
            choices = get_args(field_type)
            kwargs[name]["choices"] = choices
            choice_type = type(choices[0])
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}"
            )
            kwargs[name]["type"] = choice_type
        elif can_be_type(field_type, tuple):
            if is_type_in_union(field_type, tuple):
                field_type = get_type_from_union(field_type, tuple)
            dtypes = get_args(field_type)
            dtype = dtypes[0]
            assert all(
                d is dtype for d in dtypes if d is not Ellipsis
            ), ("All non-Ellipsis tuple elements must be of the same "
                f"type. Got {dtypes}.")
            kwargs[name]["type"] = dtype
            # Every tuple gets nargs "+", whether or not its length is fixed.
            kwargs[name]["nargs"] = "+"
        elif can_be_type(field_type, int):
            kwargs[name]["type"] = optional_int if optional else int
        elif can_be_type(field_type, float):
            kwargs[name][
                "type"] = optional_float if optional else float
        elif can_be_type(field_type, dict):
            # Dicts use the optional converter whether or not the field is
            # optional.
            kwargs[name]["type"] = optional_dict
        elif (can_be_type(field_type, str)
              or is_custom_type(field_type)):
            kwargs[name]["type"] = optional_str if optional else str
        else:
            raise ValueError(
                f"Unsupported type {field.type} for argument {name}. ")
    return kwargs
# Model arguments # Model arguments
parser.add_argument( parser.add_argument(
'--model', '--model',
@ -390,13 +377,13 @@ class EngineArgs:
'which task to use.') 'which task to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=optional_str, type=optional_type(str),
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. ' help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
parser.add_argument( parser.add_argument(
"--hf-config-path", "--hf-config-path",
type=optional_str, type=optional_type(str),
default=EngineArgs.hf_config_path, default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. ' help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
@ -408,21 +395,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.') 'the input. The generated output will contain token ids.')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=optional_str, type=optional_type(str),
default=None, default=None,
help='The specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=optional_str, type=optional_type(str),
default=None, default=None,
help='The specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=optional_str, type=optional_type(str),
default=None, default=None,
help='Revision of the huggingface tokenizer to use. ' help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. ' 'It can be a branch name, a tag name, or a commit id. '
@ -513,7 +500,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--logits-processor-pattern', '--logits-processor-pattern',
type=optional_str, type=optional_type(str),
default=None, default=None,
help='Optional regex pattern specifying valid logits processor ' help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` ' 'qualified names that can be passed with the `logits_processors` '
@ -612,7 +599,7 @@ class EngineArgs:
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=optional_str, type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
@ -921,7 +908,7 @@ class EngineArgs:
'class without changing the existing functions.') 'class without changing the existing functions.')
parser.add_argument( parser.add_argument(
"--generation-config", "--generation-config",
type=optional_str, type=optional_type(str),
default="auto", default="auto",
help="The folder path to the generation config. " help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from " "Defaults to 'auto', the generation config will be loaded from "

View File

@ -11,7 +11,7 @@ import ssl
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional, Union, get_args from typing import Optional, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template) validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath, from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host", parser.add_argument("--host",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="Host name.") help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.") parser.add_argument("--port", type=int, default=8000, help="Port number.")
@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=["*"], default=["*"],
help="Allowed headers.") help="Allowed headers.")
parser.add_argument("--api-key", parser.add_argument("--api-key",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="If provided, the server will require this key " help="If provided, the server will require this key "
"to be presented in the header.") "to be presented in the header.")
parser.add_argument( parser.add_argument(
"--lora-modules", "--lora-modules",
type=optional_str, type=optional_type(str),
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"\"base_model_name\": \"id\"}``") "\"base_model_name\": \"id\"}``")
parser.add_argument( parser.add_argument(
"--prompt-adapters", "--prompt-adapters",
type=optional_str, type=optional_type(str),
default=None, default=None,
nargs='+', nargs='+',
action=PromptAdapterParserAction, action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. " help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.") "Multiple adapters can be specified.")
parser.add_argument("--chat-template", parser.add_argument("--chat-template",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="The file path to the chat template, " help="The file path to the chat template, "
"or the template in single-line form " "or the template in single-line form "
@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. ' 'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``') 'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role", parser.add_argument("--response-role",
type=optional_str, type=optional_type(str),
default="assistant", default="assistant",
help="The role name to return if " help="The role name to return if "
"``request.add_generation_prompt=true``.") "``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile", parser.add_argument("--ssl-keyfile",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="The file path to the SSL key file.") help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile", parser.add_argument("--ssl-certfile",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="The file path to the SSL cert file.") help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs", parser.add_argument("--ssl-ca-certs",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="The CA certificates file.") help="The CA certificates file.")
parser.add_argument( parser.add_argument(
@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
) )
parser.add_argument( parser.add_argument(
"--root-path", "--root-path",
type=optional_str, type=optional_type(str),
default=None, default=None,
help="FastAPI root_path when app is behind a path based routing proxy." help="FastAPI root_path when app is behind a path based routing proxy."
) )
parser.add_argument( parser.add_argument(
"--middleware", "--middleware",
type=optional_str, type=optional_type(str),
action="append", action="append",
default=[], default=[],
help="Additional ASGI middleware to apply to the app. " help="Additional ASGI middleware to apply to the app. "

View File

@ -12,7 +12,7 @@ import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from tqdm import tqdm from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable # yapf: disable
@ -61,7 +61,7 @@ def parse_args():
"to the output URL.", "to the output URL.",
) )
parser.add_argument("--response-role", parser.add_argument("--response-role",
type=optional_str, type=optional_type(str),
default="assistant", default="assistant",
help="The role name to return if " help="The role name to return if "
"`request.add_generation_prompt=True`.") "`request.add_generation_prompt=True`.")