Improve literal dataclass field conversion to argparse argument (#17391)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-29 17:25:08 +01:00 committed by GitHub
parent 0ed27ef66c
commit 2ef5d106bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 97 additions and 18 deletions

View File

@ -11,7 +11,8 @@ import pytest
from vllm.config import PoolerConfig, config from vllm.config import PoolerConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type, get_type, is_not_builtin, is_type,
nullable_kvs, optional_type) literal_to_kwargs, nullable_kvs,
optional_type)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -71,6 +72,21 @@ def test_get_type(type_hints, type, expected):
assert get_type(type_hints, type) == expected assert get_type(type_hints, type) == expected
@pytest.mark.parametrize(("type_hints", "expected"), [
({Literal[1, 2]}, {
"type": int,
"choices": [1, 2]
}),
({Literal[1, "a"]}, Exception),
])
def test_literal_to_kwargs(type_hints, expected):
context = nullcontext()
if expected is Exception:
context = pytest.raises(expected)
with context:
assert literal_to_kwargs(type_hints) == expected
@config @config
@dataclass @dataclass
class DummyConfigClass: class DummyConfigClass:
@ -81,11 +97,15 @@ class DummyConfigClass:
optional_literal: Optional[Literal["x", "y"]] = None optional_literal: Optional[Literal["x", "y"]] = None
"""Optional literal with default None""" """Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
"""Tuple with default (1, 2, 3)""" """Tuple with variable length"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
"""Tuple with default (1, 2)""" """Tuple with fixed length"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
"""List with default [1, 2, 3]""" """List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list)
"""List with literal choices"""
literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1"""
@pytest.mark.parametrize(("type_hint", "expected"), [ @pytest.mark.parametrize(("type_hint", "expected"), [
@ -111,6 +131,12 @@ def test_get_kwargs():
# lists should work # lists should work
assert kwargs["list_n"]["type"] is int assert kwargs["list_n"]["type"] is int
assert kwargs["list_n"]["nargs"] == "+" assert kwargs["list_n"]["nargs"] == "+"
# lists with literals should have the correct choices
assert kwargs["list_literal"]["type"] is int
assert kwargs["list_literal"]["nargs"] == "+"
assert kwargs["list_literal"]["choices"] == [1, 2]
# literals of literals should have merged choices
assert kwargs["literal_literal"]["choices"] == [1, 2]
@pytest.mark.parametrize(("arg", "expected"), [ @pytest.mark.parametrize(("arg", "expected"), [

View File

@ -1,14 +1,47 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
from typing import Literal, Union
import pytest import pytest
from vllm.config import ModelConfig, PoolerConfig, get_field from vllm.config import ModelConfig, PoolerConfig, config, get_field
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
class TestConfig1:
pass
@dataclass
class TestConfig2:
a: int
"""docstring"""
@dataclass
class TestConfig3:
a: int = 1
@dataclass
class TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
@pytest.mark.parametrize(("test_config", "expected_error"), [
(TestConfig1, "must be a dataclass"),
(TestConfig2, "must have a default"),
(TestConfig3, "must have a docstring"),
(TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error):
config(test_config)
def test_get_field(): def test_get_field():
@dataclass @dataclass

View File

@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
from importlib.util import find_spec from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, get_args) Optional, Protocol, TypeVar, Union, get_args, get_origin)
import torch import torch
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
@ -177,9 +177,19 @@ def config(cls: ConfigT) -> ConfigT:
raise ValueError( raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a default value." f"Field '{f.name}' in {cls.__name__} must have a default value."
) )
if f.name not in attr_docs: if f.name not in attr_docs:
raise ValueError( raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a docstring.") f"Field '{f.name}' in {cls.__name__} must have a docstring.")
if get_origin(f.type) is Union:
args = get_args(f.type)
literal_args = [arg for arg in args if get_origin(arg) is Literal]
if len(literal_args) > 1:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must use a single "
"Literal type. Please use 'Literal[Literal1, Literal2]' "
"instead of 'Union[Literal1, Literal2]'.")
return cls return cls
@ -3166,6 +3176,8 @@ def get_served_model_name(model: str,
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"] "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
@config @config
@ -3173,9 +3185,8 @@ GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
class DecodingConfig: class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine.""" """Dataclass which contains the decoding strategy of the engine."""
guided_decoding_backend: Union[ guided_decoding_backend: GuidedDecodingBackend = \
GuidedDecodingBackendV0, "auto" if envs.VLLM_USE_V1 else "xgrammar"
GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc) """Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior contents and what the backend libraries currently support, so the behavior

View File

@ -116,6 +116,18 @@ def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
return next((th for th in type_hints if is_type(th, type)), None) return next((th for th in type_hints if is_type(th, type)), None)
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
"""Convert Literal type hints to argparse kwargs."""
type_hint = get_type(type_hints, Literal)
choices = get_args(type_hint)
choice_type = type(choices[0])
if not all(isinstance(choice, choice_type) for choice in choices):
raise ValueError(
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}")
return {"type": choice_type, "choices": sorted(choices)}
def is_not_builtin(type_hint: TypeHint) -> bool: def is_not_builtin(type_hint: TypeHint) -> bool:
"""Check if the class is not a built-in type.""" """Check if the class is not a built-in type."""
return type_hint.__module__ != "builtins" return type_hint.__module__ != "builtins"
@ -151,15 +163,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Creates --no-<name> and --<name> flags # Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal): elif contains_type(type_hints, Literal):
# Creates choices from Literal arguments kwargs[name].update(literal_to_kwargs(type_hints))
type_hint = get_type(type_hints, Literal)
choices = sorted(get_args(type_hint))
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}")
kwargs[name]["type"] = choice_type
elif contains_type(type_hints, tuple): elif contains_type(type_hints, tuple):
type_hint = get_type(type_hints, tuple) type_hint = get_type(type_hints, tuple)
types = get_args(type_hint) types = get_args(type_hint)
@ -191,6 +195,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
raise ValueError( raise ValueError(
f"Unsupported type {type_hints} for argument {name}.") f"Unsupported type {type_hints} for argument {name}.")
# If the type hint was a sequence of literals, use the helper function
# to update the type and choices
if get_origin(kwargs[name].get("type")) is Literal:
kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))
# If None is in type_hints, make the argument optional. # If None is in type_hints, make the argument optional.
# But not if it's a bool, argparse will handle this better. # But not if it's a bool, argparse will handle this better.
if type(None) in type_hints and not contains_type(type_hints, bool): if type(None) in type_hints and not contains_type(type_hints, bool):