Fix get_kwargs for case where type hint is list[Union[str, type]] (#22016)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-08-01 13:26:42 +01:00 committed by GitHub
parent 26b5f7bd2a
commit fb0e0d46fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 5 deletions

View File

@ -5,7 +5,7 @@ import json
from argparse import ArgumentError from argparse import ArgumentError
from contextlib import nullcontext from contextlib import nullcontext
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Annotated, Literal, Optional from typing import Annotated, Literal, Optional, Union
import pytest import pytest
@ -136,6 +136,8 @@ class DummyConfig:
"""List with variable length""" """List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list) list_literal: list[Literal[1, 2]] = field(default_factory=list)
"""List with literal choices""" """List with literal choices"""
list_union: list[Union[str, type[object]]] = field(default_factory=list)
"""List with union type"""
literal_literal: Literal[Literal[1], Literal[2]] = 1 literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1""" """Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict) json_tip: dict = field(default_factory=dict)
@ -187,6 +189,9 @@ def test_get_kwargs():
assert kwargs["list_literal"]["type"] is int assert kwargs["list_literal"]["type"] is int
assert kwargs["list_literal"]["nargs"] == "+" assert kwargs["list_literal"]["nargs"] == "+"
assert kwargs["list_literal"]["choices"] == [1, 2] assert kwargs["list_literal"]["choices"] == [1, 2]
# lists with unions should become str type.
# If not, we cannot know which type to use for parsing
assert kwargs["list_union"]["type"] is str
# literals of literals should have merged choices # literals of literals should have merged choices
assert kwargs["literal_literal"]["choices"] == [1, 2] assert kwargs["literal_literal"]["choices"] == [1, 2]
# dict should have json tip in help # dict should have json tip in help

View File

@ -217,10 +217,12 @@ Additionally, list elements can be passed individually using `+`:
elif contains_type(type_hints, list): elif contains_type(type_hints, list):
type_hint = get_type(type_hints, list) type_hint = get_type(type_hints, list)
types = get_args(type_hint) types = get_args(type_hint)
assert len(types) == 1, ( list_type = types[0]
"List type must have exactly one type. Got " if get_origin(list_type) is Union:
f"{type_hint} with types {types}") msg = "List type must contain str if it is a Union."
kwargs[name]["type"] = types[0] assert str in get_args(list_type), msg
list_type = str
kwargs[name]["type"] = list_type
kwargs[name]["nargs"] = "+" kwargs[name]["nargs"] = "+"
elif contains_type(type_hints, int): elif contains_type(type_hints, int):
kwargs[name]["type"] = int kwargs[name]["type"] = int