Support non-string values in JSON keys from CLI (#19471)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-06-11 17:34:04 +08:00 committed by GitHub
parent 871d6b7c74
commit a2142f0196
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 25 deletions

View File

@ -13,32 +13,32 @@ from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
class TestConfig1: class _TestConfig1:
pass pass
@dataclass @dataclass
class TestConfig2: class _TestConfig2:
a: int a: int
"""docstring""" """docstring"""
@dataclass @dataclass
class TestConfig3: class _TestConfig3:
a: int = 1 a: int = 1
@dataclass @dataclass
class TestConfig4: class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1 a: Union[Literal[1], Literal[2]] = 1
"""docstring""" """docstring"""
@pytest.mark.parametrize(("test_config", "expected_error"), [ @pytest.mark.parametrize(("test_config", "expected_error"), [
(TestConfig1, "must be a dataclass"), (_TestConfig1, "must be a dataclass"),
(TestConfig2, "must have a default"), (_TestConfig2, "must have a default"),
(TestConfig3, "must have a docstring"), (_TestConfig3, "must have a docstring"),
(TestConfig4, "must use a single Literal"), (_TestConfig4, "must use a single Literal"),
]) ])
def test_config(test_config, expected_error): def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error): with pytest.raises(Exception, match=expected_error):
@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
assert 'inductor_passes' in val assert 'inductor_passes' in val
def test_get_field():
@dataclass @dataclass
class TestConfig: class _TestConfigFields:
a: int a: int
b: dict = field(default_factory=dict) b: dict = field(default_factory=dict)
c: str = "default" c: str = "default"
with pytest.raises(ValueError):
get_field(TestConfig, "a")
b = get_field(TestConfig, "b") def test_get_field():
with pytest.raises(ValueError):
get_field(_TestConfigFields, "a")
b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field) assert isinstance(b, Field)
assert b.default is MISSING assert b.default is MISSING
assert b.default_factory is dict assert b.default_factory is dict
c = get_field(TestConfig, "c") c = get_field(_TestConfigFields, "c")
assert isinstance(c, Field) assert isinstance(c, Field)
assert c.default == "default" assert c.default == "default"
assert c.default_factory is MISSING assert c.default_factory is MISSING

View File

@ -272,6 +272,15 @@ def test_dict_args(parser):
"val5", "val5",
"--hf_overrides.key-7.key_8", "--hf_overrides.key-7.key_8",
"val6", "val6",
# Test data type detection
"--hf_overrides.key9",
"100",
"--hf_overrides.key10",
"100.0",
"--hf_overrides.key11",
"true",
"--hf_overrides.key12.key13",
"null",
] ]
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something" assert parsed_args.model_name == "something.something"
@ -286,6 +295,12 @@ def test_dict_args(parser):
"key-7": { "key-7": {
"key_8": "val6", "key_8": "val6",
}, },
"key9": 100,
"key10": 100.0,
"key11": True,
"key12": {
"key13": None,
},
} }

View File

@ -1466,7 +1466,7 @@ class FlexibleArgumentParser(ArgumentParser):
pattern = re.compile(r"(?<=--)[^\.]*") pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names # Convert underscores to dashes and vice versa in argument names
processed_args = [] processed_args = list[str]()
for arg in args: for arg in args:
if arg.startswith('--'): if arg.startswith('--'):
if '=' in arg: if '=' in arg:
@ -1483,7 +1483,7 @@ class FlexibleArgumentParser(ArgumentParser):
else: else:
processed_args.append(arg) processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str): def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value. """Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create: For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
@ -1494,7 +1494,10 @@ class FlexibleArgumentParser(ArgumentParser):
nested_dict = {key: nested_dict} nested_dict = {key: nested_dict}
return nested_dict return nested_dict
def recursive_dict_update(original: dict, update: dict): def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
):
"""Recursively updates a dictionary with another dictionary.""" """Recursively updates a dictionary with another dictionary."""
for k, v in update.items(): for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict): if isinstance(v, dict) and isinstance(original.get(k), dict):
@ -1502,19 +1505,25 @@ class FlexibleArgumentParser(ArgumentParser):
else: else:
original[k] = v original[k] = v
delete = set() delete = set[int]()
dict_args: dict[str, dict] = defaultdict(dict) dict_args = defaultdict[str, dict[str, Any]](dict)
for i, processed_arg in enumerate(processed_args): for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg: if processed_arg.startswith("--") and "." in processed_arg:
if "=" in processed_arg: if "=" in processed_arg:
processed_arg, value = processed_arg.split("=", 1) processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg: if "." not in processed_arg:
# False positive, . was only in the value # False positive, . was only in the value
continue continue
else: else:
value = processed_args[i + 1] value_str = processed_args[i + 1]
delete.add(i + 1) delete.add(i + 1)
key, *keys = processed_arg.split(".") key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict # Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value) arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict) recursive_dict_update(dict_args[key], arg_dict)