diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 1d1926068d28..c282bf002304 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -5,7 +5,7 @@ import json from argparse import ArgumentError from contextlib import nullcontext from dataclasses import dataclass, field -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal, Optional, Union import pytest @@ -136,6 +136,8 @@ class DummyConfig: """List with variable length""" list_literal: list[Literal[1, 2]] = field(default_factory=list) """List with literal choices""" + list_union: list[Union[str, type[object]]] = field(default_factory=list) + """List with union type""" literal_literal: Literal[Literal[1], Literal[2]] = 1 """Literal of literals with default 1""" json_tip: dict = field(default_factory=dict) @@ -187,6 +189,9 @@ def test_get_kwargs(): assert kwargs["list_literal"]["type"] is int assert kwargs["list_literal"]["nargs"] == "+" assert kwargs["list_literal"]["choices"] == [1, 2] + # lists with unions should become str type. + # If not, we cannot know which type to use for parsing + assert kwargs["list_union"]["type"] is str # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] # dict should have json tip in help diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index efa077a88270..f938f19b9046 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -217,10 +217,12 @@ Additionally, list elements can be passed individually using `+`: elif contains_type(type_hints, list): type_hint = get_type(type_hints, list) types = get_args(type_hint) - assert len(types) == 1, ( - "List type must have exactly one type. Got " - f"{type_hint} with types {types}") - kwargs[name]["type"] = types[0] + list_type = types[0] + if get_origin(list_type) is Union: + msg = "List type must contain str if it is a Union." + assert str in get_args(list_type), msg + list_type = str + kwargs[name]["type"] = list_type kwargs[name]["nargs"] = "+" elif contains_type(type_hints, int): kwargs[name]["type"] = int