Fix interaction between Optional and Annotated in CLI typing (#19093)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Yikun Jiang <yikun@apache.org>
This commit is contained in:
Harry Mellor 2025-06-03 22:07:19 +01:00 committed by GitHub
parent e31446b6c8
commit 6865fe0074
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 10 deletions

View File

@ -5,14 +5,14 @@ import json
from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
import pytest
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
get_type, get_type_hints, is_not_builtin,
is_type, literal_to_kwargs, nullable_kvs,
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
@pytest.mark.parametrize(
("type_hint", "expected"), [
(Annotated[int, "annotation"], {int}),
(Optional[int], {int, type(None)}),
(Annotated[Optional[int], "annotation"], {int, type(None)}),
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
],
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
def test_get_type_hints(type_hint, expected):
assert get_type_hints(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfig)
print(kwargs)

View File

@ -15,7 +15,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
import regex as re
import torch
from pydantic import SkipValidation, TypeAdapter, ValidationError
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool:
return type_hint.__module__ != "builtins"
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
"""Extract type hints from Annotated or Union type hints."""
type_hints: set[TypeHint] = set()
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is Annotated:
type_hints.update(get_type_hints(args[0]))
elif origin is Union:
for arg in args:
type_hints.update(get_type_hints(arg))
else:
type_hints.add(type_hint)
return type_hints
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
predicate = lambda arg: not isinstance(arg, SkipValidation)
type_hints.update(filter(predicate, get_args(field.type)))
else:
type_hints.add(field.type)
type_hints: set[TypeHint] = get_type_hints(field.type)
# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))