From a2142f01969a7c327da69be1e82a6c01a530116c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 11 Jun 2025 17:34:04 +0800 Subject: [PATCH] Support non-string values in JSON keys from CLI (#19471) Signed-off-by: DarkLight1337 --- tests/test_config.py | 36 ++++++++++++++++++------------------ tests/test_utils.py | 15 +++++++++++++++ vllm/utils.py | 23 ++++++++++++++++------- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index ce383e1b420a..715ef09dd307 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,32 +13,32 @@ from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform -class TestConfig1: +class _TestConfig1: pass @dataclass -class TestConfig2: +class _TestConfig2: a: int """docstring""" @dataclass -class TestConfig3: +class _TestConfig3: a: int = 1 @dataclass -class TestConfig4: +class _TestConfig4: a: Union[Literal[1], Literal[2]] = 1 """docstring""" @pytest.mark.parametrize(("test_config", "expected_error"), [ - (TestConfig1, "must be a dataclass"), - (TestConfig2, "must have a default"), - (TestConfig3, "must have a docstring"), - (TestConfig4, "must use a single Literal"), + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), ]) def test_config(test_config, expected_error): with pytest.raises(Exception, match=expected_error): @@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds(): assert 'inductor_passes' in val +@dataclass +class _TestConfigFields: + a: int + b: dict = field(default_factory=dict) + c: str = "default" + + def test_get_field(): - - @dataclass - class TestConfig: - a: int - b: dict = field(default_factory=dict) - c: str = "default" - with pytest.raises(ValueError): - get_field(TestConfig, "a") + get_field(_TestConfigFields, "a") - b = get_field(TestConfig, "b") + b = get_field(_TestConfigFields, "b") assert isinstance(b, Field) assert b.default is MISSING assert b.default_factory is dict - c = get_field(TestConfig, "c") + c = get_field(_TestConfigFields, "c") assert isinstance(c, Field) assert c.default == "default" assert c.default_factory is MISSING diff --git a/tests/test_utils.py b/tests/test_utils.py index a2fd845ea54b..913188455d8e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -272,6 +272,15 @@ def test_dict_args(parser): "val5", "--hf_overrides.key-7.key_8", "val6", + # Test data type detection + "--hf_overrides.key9", + "100", + "--hf_overrides.key10", + "100.0", + "--hf_overrides.key11", + "true", + "--hf_overrides.key12.key13", + "null", ] parsed_args = parser.parse_args(args) assert parsed_args.model_name == "something.something" @@ -286,6 +295,12 @@ def test_dict_args(parser): "key-7": { "key_8": "val6", }, + "key9": 100, + "key10": 100.0, + "key11": True, + "key12": { + "key13": None, + }, } diff --git a/vllm/utils.py b/vllm/utils.py index d8dd5f284ab3..342241d0dd8a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1466,7 +1466,7 @@ class FlexibleArgumentParser(ArgumentParser): pattern = re.compile(r"(?<=--)[^\.]*") # Convert underscores to dashes and vice versa in argument names - processed_args = [] + processed_args = list[str]() for arg in args: if arg.startswith('--'): if '=' in arg: @@ -1483,7 +1483,7 @@ class FlexibleArgumentParser(ArgumentParser): else: processed_args.append(arg) - def create_nested_dict(keys: list[str], value: str): + def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]: """Creates a nested dictionary from a list of keys and a value. For example, `keys = ["a", "b", "c"]` and `value = 1` will create: @@ -1494,7 +1494,10 @@ class FlexibleArgumentParser(ArgumentParser): nested_dict = {key: nested_dict} return nested_dict - def recursive_dict_update(original: dict, update: dict): + def recursive_dict_update( + original: dict[str, Any], + update: dict[str, Any], + ): """Recursively updates a dictionary with another dictionary.""" for k, v in update.items(): if isinstance(v, dict) and isinstance(original.get(k), dict): @@ -1502,19 +1505,25 @@ class FlexibleArgumentParser(ArgumentParser): else: original[k] = v - delete = set() - dict_args: dict[str, dict] = defaultdict(dict) + delete = set[int]() + dict_args = defaultdict[str, dict[str, Any]](dict) for i, processed_arg in enumerate(processed_args): if processed_arg.startswith("--") and "." in processed_arg: if "=" in processed_arg: - processed_arg, value = processed_arg.split("=", 1) + processed_arg, value_str = processed_arg.split("=", 1) if "." not in processed_arg: # False positive, . was only in the value continue else: - value = processed_args[i + 1] + value_str = processed_args[i + 1] delete.add(i + 1) + key, *keys = processed_arg.split(".") + try: + value = json.loads(value_str) + except json.decoder.JSONDecodeError: + value = value_str + # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) recursive_dict_update(dict_args[key], arg_dict)