From 2ef5d106bbf269563889308039ab10b149b57008 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:25:08 +0100 Subject: [PATCH] Improve literal dataclass field conversion to argparse argument (#17391) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/engine/test_arg_utils.py | 34 +++++++++++++++++++++++++++++---- tests/test_config.py | 35 +++++++++++++++++++++++++++++++++- vllm/config.py | 19 ++++++++++++++---- vllm/engine/arg_utils.py | 27 +++++++++++++++++--------- 4 files changed, 97 insertions(+), 18 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 052d5793c1b3..2c86658022c0 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -11,7 +11,8 @@ import pytest from vllm.config import PoolerConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, - nullable_kvs, optional_type) + literal_to_kwargs, nullable_kvs, + optional_type) from vllm.utils import FlexibleArgumentParser @@ -71,6 +72,21 @@ def test_get_type(type_hints, type, expected): assert get_type(type_hints, type) == expected +@pytest.mark.parametrize(("type_hints", "expected"), [ + ({Literal[1, 2]}, { + "type": int, + "choices": [1, 2] + }), + ({Literal[1, "a"]}, Exception), +]) +def test_literal_to_kwargs(type_hints, expected): + context = nullcontext() + if expected is Exception: + context = pytest.raises(expected) + with context: + assert literal_to_kwargs(type_hints) == expected + + @config @dataclass class DummyConfigClass: @@ -81,11 +97,15 @@ class DummyConfigClass: optional_literal: Optional[Literal["x", "y"]] = None """Optional literal with default None""" tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) - """Tuple with default (1, 2, 3)""" + """Tuple with variable length""" tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) - """Tuple with default (1, 2)""" + """Tuple with fixed length""" list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) - """List with default [1, 2, 3]""" + """List with variable length""" + list_literal: list[Literal[1, 2]] = field(default_factory=list) + """List with literal choices""" + literal_literal: Literal[Literal[1], Literal[2]] = 1 + """Literal of literals with default 1""" @pytest.mark.parametrize(("type_hint", "expected"), [ @@ -111,6 +131,12 @@ def test_get_kwargs(): # lists should work assert kwargs["list_n"]["type"] is int assert kwargs["list_n"]["nargs"] == "+" + # lists with literals should have the correct choices + assert kwargs["list_literal"]["type"] is int + assert kwargs["list_literal"]["nargs"] == "+" + assert kwargs["list_literal"]["choices"] == [1, 2] + # literals of literals should have merged choices + assert kwargs["literal_literal"]["choices"] == [1, 2] @pytest.mark.parametrize(("arg", "expected"), [ diff --git a/tests/test_config.py b/tests/test_config.py index 53db91e81c41..2e5da8128d99 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,14 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import MISSING, Field, asdict, dataclass, field +from typing import Literal, Union import pytest -from vllm.config import ModelConfig, PoolerConfig, get_field +from vllm.config import ModelConfig, PoolerConfig, config, get_field from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform +class TestConfig1: + pass + + +@dataclass +class TestConfig2: + a: int + """docstring""" + + +@dataclass +class TestConfig3: + a: int = 1 + + +@dataclass +class TestConfig4: + a: Union[Literal[1], Literal[2]] = 1 + """docstring""" + + +@pytest.mark.parametrize(("test_config", "expected_error"), [ + (TestConfig1, "must be a dataclass"), + (TestConfig2, "must have a default"), + (TestConfig3, "must have a docstring"), + (TestConfig4, "must use a single Literal"), +]) +def test_config(test_config, expected_error): + with pytest.raises(Exception, match=expected_error): + config(test_config) + + def test_get_field(): @dataclass diff --git a/vllm/config.py b/vllm/config.py index c1c72846d93a..8f927835d2d4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union, get_args) + Optional, Protocol, TypeVar, Union, get_args, get_origin) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -177,9 +177,19 @@ def config(cls: ConfigT) -> ConfigT: raise ValueError( f"Field '{f.name}' in {cls.__name__} must have a default value." ) + if f.name not in attr_docs: raise ValueError( f"Field '{f.name}' in {cls.__name__} must have a docstring.") + + if get_origin(f.type) is Union: + args = get_args(f.type) + literal_args = [arg for arg in args if get_origin(arg) is Literal] + if len(literal_args) > 1: + raise ValueError( + f"Field '{f.name}' in {cls.__name__} must use a single " + "Literal type. Please use 'Literal[Literal1, Literal2]' " + "instead of 'Union[Literal1, Literal2]'.") return cls @@ -3166,6 +3176,8 @@ def get_served_model_name(model: str, GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", "xgrammar", "guidance"] GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] +GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, + GuidedDecodingBackendV1] @config @@ -3173,9 +3185,8 @@ GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] class DecodingConfig: """Dataclass which contains the decoding strategy of the engine.""" - guided_decoding_backend: Union[ - GuidedDecodingBackendV0, - GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar" + guided_decoding_backend: GuidedDecodingBackend = \ + "auto" if envs.VLLM_USE_V1 else "xgrammar" """Which engine will be used for guided decoding (JSON schema / regex etc) by default. With "auto", we will make opinionated choices based on request contents and what the backend libraries currently support, so the behavior diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ad26241235c0..fe688025f9b1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -116,6 +116,18 @@ def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT: return next((th for th in type_hints if is_type(th, type)), None) +def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: + """Convert Literal type hints to argparse kwargs.""" + type_hint = get_type(type_hints, Literal) + choices = get_args(type_hint) + choice_type = type(choices[0]) + if not all(isinstance(choice, choice_type) for choice in choices): + raise ValueError( + "All choices must be of the same type. " + f"Got {choices} with types {[type(c) for c in choices]}") + return {"type": choice_type, "choices": sorted(choices)} + + def is_not_builtin(type_hint: TypeHint) -> bool: """Check if the class is not a built-in type.""" return type_hint.__module__ != "builtins" @@ -151,15 +163,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Creates --no- and -- flags kwargs[name]["action"] = argparse.BooleanOptionalAction elif contains_type(type_hints, Literal): - # Creates choices from Literal arguments - type_hint = get_type(type_hints, Literal) - choices = sorted(get_args(type_hint)) - kwargs[name]["choices"] = choices - choice_type = type(choices[0]) - assert all(type(c) is choice_type for c in choices), ( - "All choices must be of the same type. " - f"Got {choices} with types {[type(c) for c in choices]}") - kwargs[name]["type"] = choice_type + kwargs[name].update(literal_to_kwargs(type_hints)) elif contains_type(type_hints, tuple): type_hint = get_type(type_hints, tuple) types = get_args(type_hint) @@ -191,6 +195,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: raise ValueError( f"Unsupported type {type_hints} for argument {name}.") + # If the type hint was a sequence of literals, use the helper function + # to update the type and choices + if get_origin(kwargs[name].get("type")) is Literal: + kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]})) + # If None is in type_hints, make the argument optional. # But not if it's a bool, argparse will handle this better. if type(None) in type_hints and not contains_type(type_hints, bool):