Support non-string values in JSON keys from CLI (#19471)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-06-11 17:34:04 +08:00 committed by GitHub
parent 871d6b7c74
commit a2142f0196
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 25 deletions

View File

@ -13,32 +13,32 @@ from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
class TestConfig1:
class _TestConfig1:
pass
@dataclass
class TestConfig2:
class _TestConfig2:
a: int
"""docstring"""
@dataclass
class TestConfig3:
class _TestConfig3:
a: int = 1
@dataclass
class TestConfig4:
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
@pytest.mark.parametrize(("test_config", "expected_error"), [
(TestConfig1, "must be a dataclass"),
(TestConfig2, "must have a default"),
(TestConfig3, "must have a docstring"),
(TestConfig4, "must use a single Literal"),
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error):
@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
assert 'inductor_passes' in val
@dataclass
class _TestConfigFields:
a: int
b: dict = field(default_factory=dict)
c: str = "default"
def test_get_field():
@dataclass
class TestConfig:
a: int
b: dict = field(default_factory=dict)
c: str = "default"
with pytest.raises(ValueError):
get_field(TestConfig, "a")
get_field(_TestConfigFields, "a")
b = get_field(TestConfig, "b")
b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field)
assert b.default is MISSING
assert b.default_factory is dict
c = get_field(TestConfig, "c")
c = get_field(_TestConfigFields, "c")
assert isinstance(c, Field)
assert c.default == "default"
assert c.default_factory is MISSING

View File

@ -272,6 +272,15 @@ def test_dict_args(parser):
"val5",
"--hf_overrides.key-7.key_8",
"val6",
# Test data type detection
"--hf_overrides.key9",
"100",
"--hf_overrides.key10",
"100.0",
"--hf_overrides.key11",
"true",
"--hf_overrides.key12.key13",
"null",
]
parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something"
@ -286,6 +295,12 @@ def test_dict_args(parser):
"key-7": {
"key_8": "val6",
},
"key9": 100,
"key10": 100.0,
"key11": True,
"key12": {
"key13": None,
},
}

View File

@ -1466,7 +1466,7 @@ class FlexibleArgumentParser(ArgumentParser):
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = []
processed_args = list[str]()
for arg in args:
if arg.startswith('--'):
if '=' in arg:
@ -1483,7 +1483,7 @@ class FlexibleArgumentParser(ArgumentParser):
else:
processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str):
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
@ -1494,7 +1494,10 @@ class FlexibleArgumentParser(ArgumentParser):
nested_dict = {key: nested_dict}
return nested_dict
def recursive_dict_update(original: dict, update: dict):
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
):
"""Recursively updates a dictionary with another dictionary."""
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
@ -1502,19 +1505,25 @@ class FlexibleArgumentParser(ArgumentParser):
else:
original[k] = v
delete = set()
dict_args: dict[str, dict] = defaultdict(dict)
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value = processed_arg.split("=", 1)
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive: the "." appeared only in the value, not in the key
continue
else:
value = processed_args[i + 1]
value_str = processed_args[i + 1]
delete.add(i + 1)
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict)