diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index ab78aa7da21b..cfbc7c245ffd 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -5,14 +5,14 @@ import json from argparse import ArgumentError, ArgumentTypeError from contextlib import nullcontext from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import Annotated, Literal, Optional import pytest from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, - get_type, is_not_builtin, is_type, - literal_to_kwargs, nullable_kvs, + get_type, get_type_hints, is_not_builtin, + is_type, literal_to_kwargs, nullable_kvs, optional_type, parse_type) from vllm.utils import FlexibleArgumentParser @@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected +@pytest.mark.parametrize( + ("type_hint", "expected"), [ + (Annotated[int, "annotation"], {int}), + (Optional[int], {int, type(None)}), + (Annotated[Optional[int], "annotation"], {int, type(None)}), + (Optional[Annotated[int, "annotation"]], {int, type(None)}), + ], + ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"]) +def test_get_type_hints(type_hint, expected): + assert get_type_hints(type_hint) == expected + + def test_get_kwargs(): kwargs = get_kwargs(DummyConfig) print(kwargs) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 587a23134fe9..2197d44ca825 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -15,7 +15,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, import regex as re import torch -from pydantic import SkipValidation, TypeAdapter, ValidationError +from pydantic import TypeAdapter, ValidationError from typing_extensions import TypeIs, deprecated import vllm.envs as envs @@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool: return type_hint.__module__ != "builtins" +def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: + """Extract type hints from Annotated or Union type hints.""" + type_hints: set[TypeHint] = set() + origin = get_origin(type_hint) + args = get_args(type_hint) + + if origin is Annotated: + type_hints.update(get_type_hints(args[0])) + elif origin is Union: + for arg in args: + type_hints.update(get_type_hints(arg)) + else: + type_hints.add(type_hint) + + return type_hints + + def get_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) kwargs = {} for field in fields(cls): # Get the set of possible types for the field - type_hints: set[TypeHint] = set() - if get_origin(field.type) in {Union, Annotated}: - predicate = lambda arg: not isinstance(arg, SkipValidation) - type_hints.update(filter(predicate, get_args(field.type))) - else: - type_hints.add(field.type) + type_hints: set[TypeHint] = get_type_hints(field.type) # If the field is a dataclass, we can use the model_validate_json generator = (th for th in type_hints if is_dataclass(th))