mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
Fix interaction between Optional and Annotated in CLI typing (#19093)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Yikun Jiang <yikun@apache.org>
This commit is contained in:
parent
e31446b6c8
commit
6865fe0074
@ -5,14 +5,14 @@ import json
|
||||
from argparse import ArgumentError, ArgumentTypeError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, is_not_builtin, is_type,
|
||||
literal_to_kwargs, nullable_kvs,
|
||||
get_type, get_type_hints, is_not_builtin,
|
||||
is_type, literal_to_kwargs, nullable_kvs,
|
||||
optional_type, parse_type)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected):
|
||||
assert is_not_builtin(type_hint) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "expected"), [
|
||||
(Annotated[int, "annotation"], {int}),
|
||||
(Optional[int], {int, type(None)}),
|
||||
(Annotated[Optional[int], "annotation"], {int, type(None)}),
|
||||
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
|
||||
],
|
||||
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
|
||||
def test_get_type_hints(type_hint, expected):
|
||||
assert get_type_hints(type_hint) == expected
|
||||
|
||||
|
||||
def test_get_kwargs():
|
||||
kwargs = get_kwargs(DummyConfig)
|
||||
print(kwargs)
|
||||
|
||||
@ -15,7 +15,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import SkipValidation, TypeAdapter, ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool:
|
||||
return type_hint.__module__ != "builtins"
|
||||
|
||||
|
||||
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
|
||||
"""Extract type hints from Annotated or Union type hints."""
|
||||
type_hints: set[TypeHint] = set()
|
||||
origin = get_origin(type_hint)
|
||||
args = get_args(type_hint)
|
||||
|
||||
if origin is Annotated:
|
||||
type_hints.update(get_type_hints(args[0]))
|
||||
elif origin is Union:
|
||||
for arg in args:
|
||||
type_hints.update(get_type_hints(arg))
|
||||
else:
|
||||
type_hints.add(type_hint)
|
||||
|
||||
return type_hints
|
||||
|
||||
|
||||
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
kwargs = {}
|
||||
for field in fields(cls):
|
||||
# Get the set of possible types for the field
|
||||
type_hints: set[TypeHint] = set()
|
||||
if get_origin(field.type) in {Union, Annotated}:
|
||||
predicate = lambda arg: not isinstance(arg, SkipValidation)
|
||||
type_hints.update(filter(predicate, get_args(field.type)))
|
||||
else:
|
||||
type_hints.add(field.type)
|
||||
type_hints: set[TypeHint] = get_type_hints(field.type)
|
||||
|
||||
# If the field is a dataclass, we can use the model_validate_json
|
||||
generator = (th for th in type_hints if is_dataclass(th))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user