mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:15:42 +08:00
Fix interaction between Optional and Annotated in CLI typing (#19093)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Yikun Jiang <yikun@apache.org>
This commit is contained in:
parent
e31446b6c8
commit
6865fe0074
@ -5,14 +5,14 @@ import json
|
|||||||
from argparse import ArgumentError, ArgumentTypeError
|
from argparse import ArgumentError, ArgumentTypeError
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal, Optional
|
from typing import Annotated, Literal, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, config
|
from vllm.config import CompilationConfig, config
|
||||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||||
get_type, is_not_builtin, is_type,
|
get_type, get_type_hints, is_not_builtin,
|
||||||
literal_to_kwargs, nullable_kvs,
|
is_type, literal_to_kwargs, nullable_kvs,
|
||||||
optional_type, parse_type)
|
optional_type, parse_type)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected):
|
|||||||
assert is_not_builtin(type_hint) == expected
|
assert is_not_builtin(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("type_hint", "expected"), [
|
||||||
|
(Annotated[int, "annotation"], {int}),
|
||||||
|
(Optional[int], {int, type(None)}),
|
||||||
|
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
||||||
|
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
||||||
|
],
|
||||||
|
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
|
||||||
|
def test_get_type_hints(type_hint, expected):
|
||||||
|
assert get_type_hints(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
def test_get_kwargs():
|
def test_get_kwargs():
|
||||||
kwargs = get_kwargs(DummyConfig)
|
kwargs = get_kwargs(DummyConfig)
|
||||||
print(kwargs)
|
print(kwargs)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
|||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
from pydantic import SkipValidation, TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from typing_extensions import TypeIs, deprecated
|
from typing_extensions import TypeIs, deprecated
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool:
|
|||||||
return type_hint.__module__ != "builtins"
|
return type_hint.__module__ != "builtins"
|
||||||
|
|
||||||
|
|
||||||
|
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
|
||||||
|
"""Extract type hints from Annotated or Union type hints."""
|
||||||
|
type_hints: set[TypeHint] = set()
|
||||||
|
origin = get_origin(type_hint)
|
||||||
|
args = get_args(type_hint)
|
||||||
|
|
||||||
|
if origin is Annotated:
|
||||||
|
type_hints.update(get_type_hints(args[0]))
|
||||||
|
elif origin is Union:
|
||||||
|
for arg in args:
|
||||||
|
type_hints.update(get_type_hints(arg))
|
||||||
|
else:
|
||||||
|
type_hints.add(type_hint)
|
||||||
|
|
||||||
|
return type_hints
|
||||||
|
|
||||||
|
|
||||||
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||||
cls_docs = get_attr_docs(cls)
|
cls_docs = get_attr_docs(cls)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for field in fields(cls):
|
for field in fields(cls):
|
||||||
# Get the set of possible types for the field
|
# Get the set of possible types for the field
|
||||||
type_hints: set[TypeHint] = set()
|
type_hints: set[TypeHint] = get_type_hints(field.type)
|
||||||
if get_origin(field.type) in {Union, Annotated}:
|
|
||||||
predicate = lambda arg: not isinstance(arg, SkipValidation)
|
|
||||||
type_hints.update(filter(predicate, get_args(field.type)))
|
|
||||||
else:
|
|
||||||
type_hints.add(field.type)
|
|
||||||
|
|
||||||
# If the field is a dataclass, we can use the model_validate_json
|
# If the field is a dataclass, we can use the model_validate_json
|
||||||
generator = (th for th in type_hints if is_dataclass(th))
|
generator = (th for th in type_hints if is_dataclass(th))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user