mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 21:05:46 +08:00
Improve literal dataclass field conversion to argparse argument (#17391)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
0ed27ef66c
commit
2ef5d106bb
@ -11,7 +11,8 @@ import pytest
|
|||||||
from vllm.config import PoolerConfig, config
|
from vllm.config import PoolerConfig, config
|
||||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||||
get_type, is_not_builtin, is_type,
|
get_type, is_not_builtin, is_type,
|
||||||
nullable_kvs, optional_type)
|
literal_to_kwargs, nullable_kvs,
|
||||||
|
optional_type)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@ -71,6 +72,21 @@ def test_get_type(type_hints, type, expected):
|
|||||||
assert get_type(type_hints, type) == expected
|
assert get_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type_hints", "expected"), [
|
||||||
|
({Literal[1, 2]}, {
|
||||||
|
"type": int,
|
||||||
|
"choices": [1, 2]
|
||||||
|
}),
|
||||||
|
({Literal[1, "a"]}, Exception),
|
||||||
|
])
|
||||||
|
def test_literal_to_kwargs(type_hints, expected):
|
||||||
|
context = nullcontext()
|
||||||
|
if expected is Exception:
|
||||||
|
context = pytest.raises(expected)
|
||||||
|
with context:
|
||||||
|
assert literal_to_kwargs(type_hints) == expected
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class DummyConfigClass:
|
class DummyConfigClass:
|
||||||
@ -81,11 +97,15 @@ class DummyConfigClass:
|
|||||||
optional_literal: Optional[Literal["x", "y"]] = None
|
optional_literal: Optional[Literal["x", "y"]] = None
|
||||||
"""Optional literal with default None"""
|
"""Optional literal with default None"""
|
||||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||||
"""Tuple with default (1, 2, 3)"""
|
"""Tuple with variable length"""
|
||||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||||
"""Tuple with default (1, 2)"""
|
"""Tuple with fixed length"""
|
||||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||||
"""List with default [1, 2, 3]"""
|
"""List with variable length"""
|
||||||
|
list_literal: list[Literal[1, 2]] = field(default_factory=list)
|
||||||
|
"""List with literal choices"""
|
||||||
|
literal_literal: Literal[Literal[1], Literal[2]] = 1
|
||||||
|
"""Literal of literals with default 1"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||||
@ -111,6 +131,12 @@ def test_get_kwargs():
|
|||||||
# lists should work
|
# lists should work
|
||||||
assert kwargs["list_n"]["type"] is int
|
assert kwargs["list_n"]["type"] is int
|
||||||
assert kwargs["list_n"]["nargs"] == "+"
|
assert kwargs["list_n"]["nargs"] == "+"
|
||||||
|
# lists with literals should have the correct choices
|
||||||
|
assert kwargs["list_literal"]["type"] is int
|
||||||
|
assert kwargs["list_literal"]["nargs"] == "+"
|
||||||
|
assert kwargs["list_literal"]["choices"] == [1, 2]
|
||||||
|
# literals of literals should have merged choices
|
||||||
|
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("arg", "expected"), [
|
@pytest.mark.parametrize(("arg", "expected"), [
|
||||||
|
|||||||
@ -1,14 +1,47 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||||
|
from typing import Literal, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import ModelConfig, PoolerConfig, get_field
|
from vllm.config import ModelConfig, PoolerConfig, config, get_field
|
||||||
from vllm.model_executor.layers.pooler import PoolingType
|
from vllm.model_executor.layers.pooler import PoolingType
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig1:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestConfig2:
|
||||||
|
a: int
|
||||||
|
"""docstring"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestConfig3:
|
||||||
|
a: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestConfig4:
|
||||||
|
a: Union[Literal[1], Literal[2]] = 1
|
||||||
|
"""docstring"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("test_config", "expected_error"), [
|
||||||
|
(TestConfig1, "must be a dataclass"),
|
||||||
|
(TestConfig2, "must have a default"),
|
||||||
|
(TestConfig3, "must have a docstring"),
|
||||||
|
(TestConfig4, "must use a single Literal"),
|
||||||
|
])
|
||||||
|
def test_config(test_config, expected_error):
|
||||||
|
with pytest.raises(Exception, match=expected_error):
|
||||||
|
config(test_config)
|
||||||
|
|
||||||
|
|
||||||
def test_get_field():
|
def test_get_field():
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
|||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||||
Optional, Protocol, TypeVar, Union, get_args)
|
Optional, Protocol, TypeVar, Union, get_args, get_origin)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@ -177,9 +177,19 @@ def config(cls: ConfigT) -> ConfigT:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Field '{f.name}' in {cls.__name__} must have a default value."
|
f"Field '{f.name}' in {cls.__name__} must have a default value."
|
||||||
)
|
)
|
||||||
|
|
||||||
if f.name not in attr_docs:
|
if f.name not in attr_docs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
|
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
|
||||||
|
|
||||||
|
if get_origin(f.type) is Union:
|
||||||
|
args = get_args(f.type)
|
||||||
|
literal_args = [arg for arg in args if get_origin(arg) is Literal]
|
||||||
|
if len(literal_args) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Field '{f.name}' in {cls.__name__} must use a single "
|
||||||
|
"Literal type. Please use 'Literal[Literal1, Literal2]' "
|
||||||
|
"instead of 'Union[Literal1, Literal2]'.")
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
@ -3166,6 +3176,8 @@ def get_served_model_name(model: str,
|
|||||||
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
|
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
|
||||||
"xgrammar", "guidance"]
|
"xgrammar", "guidance"]
|
||||||
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
|
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
|
||||||
|
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
|
||||||
|
GuidedDecodingBackendV1]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@ -3173,9 +3185,8 @@ GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
|
|||||||
class DecodingConfig:
|
class DecodingConfig:
|
||||||
"""Dataclass which contains the decoding strategy of the engine."""
|
"""Dataclass which contains the decoding strategy of the engine."""
|
||||||
|
|
||||||
guided_decoding_backend: Union[
|
guided_decoding_backend: GuidedDecodingBackend = \
|
||||||
GuidedDecodingBackendV0,
|
"auto" if envs.VLLM_USE_V1 else "xgrammar"
|
||||||
GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
|
|
||||||
"""Which engine will be used for guided decoding (JSON schema / regex etc)
|
"""Which engine will be used for guided decoding (JSON schema / regex etc)
|
||||||
by default. With "auto", we will make opinionated choices based on request
|
by default. With "auto", we will make opinionated choices based on request
|
||||||
contents and what the backend libraries currently support, so the behavior
|
contents and what the backend libraries currently support, so the behavior
|
||||||
|
|||||||
@ -116,6 +116,18 @@ def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
|
|||||||
return next((th for th in type_hints if is_type(th, type)), None)
|
return next((th for th in type_hints if is_type(th, type)), None)
|
||||||
|
|
||||||
|
|
||||||
|
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
|
||||||
|
"""Convert Literal type hints to argparse kwargs."""
|
||||||
|
type_hint = get_type(type_hints, Literal)
|
||||||
|
choices = get_args(type_hint)
|
||||||
|
choice_type = type(choices[0])
|
||||||
|
if not all(isinstance(choice, choice_type) for choice in choices):
|
||||||
|
raise ValueError(
|
||||||
|
"All choices must be of the same type. "
|
||||||
|
f"Got {choices} with types {[type(c) for c in choices]}")
|
||||||
|
return {"type": choice_type, "choices": sorted(choices)}
|
||||||
|
|
||||||
|
|
||||||
def is_not_builtin(type_hint: TypeHint) -> bool:
|
def is_not_builtin(type_hint: TypeHint) -> bool:
|
||||||
"""Check if the class is not a built-in type."""
|
"""Check if the class is not a built-in type."""
|
||||||
return type_hint.__module__ != "builtins"
|
return type_hint.__module__ != "builtins"
|
||||||
@ -151,15 +163,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
# Creates --no-<name> and --<name> flags
|
# Creates --no-<name> and --<name> flags
|
||||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||||
elif contains_type(type_hints, Literal):
|
elif contains_type(type_hints, Literal):
|
||||||
# Creates choices from Literal arguments
|
kwargs[name].update(literal_to_kwargs(type_hints))
|
||||||
type_hint = get_type(type_hints, Literal)
|
|
||||||
choices = sorted(get_args(type_hint))
|
|
||||||
kwargs[name]["choices"] = choices
|
|
||||||
choice_type = type(choices[0])
|
|
||||||
assert all(type(c) is choice_type for c in choices), (
|
|
||||||
"All choices must be of the same type. "
|
|
||||||
f"Got {choices} with types {[type(c) for c in choices]}")
|
|
||||||
kwargs[name]["type"] = choice_type
|
|
||||||
elif contains_type(type_hints, tuple):
|
elif contains_type(type_hints, tuple):
|
||||||
type_hint = get_type(type_hints, tuple)
|
type_hint = get_type(type_hints, tuple)
|
||||||
types = get_args(type_hint)
|
types = get_args(type_hint)
|
||||||
@ -191,6 +195,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported type {type_hints} for argument {name}.")
|
f"Unsupported type {type_hints} for argument {name}.")
|
||||||
|
|
||||||
|
# If the type hint was a sequence of literals, use the helper function
|
||||||
|
# to update the type and choices
|
||||||
|
if get_origin(kwargs[name].get("type")) is Literal:
|
||||||
|
kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))
|
||||||
|
|
||||||
# If None is in type_hints, make the argument optional.
|
# If None is in type_hints, make the argument optional.
|
||||||
# But not if it's a bool, argparse will handle this better.
|
# But not if it's a bool, argparse will handle this better.
|
||||||
if type(None) in type_hints and not contains_type(type_hints, bool):
|
if type(None) in type_hints and not contains_type(type_hints, bool):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user