From b18201fe060a3ddcc088f8aea3cf1d7c4b461288 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 16 May 2025 05:05:34 +0100 Subject: [PATCH] Allow users to pass arbitrary JSON keys from CLI (#18208) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/engine/test_arg_utils.py | 4 +-- tests/test_utils.py | 25 ++++++++++++++++++ vllm/engine/arg_utils.py | 6 ++++- vllm/utils.py | 46 ++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index ce8873d58d4d..05d9cfc7ab74 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -181,8 +181,8 @@ def test_get_kwargs(): # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] # dict should have json tip in help - json_tip = "\n\nShould be a valid JSON string." - assert kwargs["json_tip"]["help"].endswith(json_tip) + json_tip = "Should either be a valid JSON string or JSON keys" + assert json_tip in kwargs["json_tip"]["help"] # nested config should should construct the nested config assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # from_cli configs should be constructed with the correct method diff --git a/tests/test_utils.py b/tests/test_utils.py index ea7db0a79c86..0b88d05efeaa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import asyncio import hashlib +import json import pickle import socket from collections.abc import AsyncIterator @@ -138,6 +139,7 @@ def parser(): parser.add_argument('--model-name') parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') + parser.add_argument('--hf-overrides', type=json.loads) return parser @@ -251,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file): parser_with_config.parse_args(['serve', '--config', cli_config_file]) +def test_dict_args(parser): + args = [ + "--model-name=something.something", + "--hf-overrides.key1", + "val1", + "--hf-overrides.key2.key3", + "val2", + "--hf-overrides.key2.key4", + "val3", + "--hf-overrides.key5=val4", + ] + parsed_args = parser.parse_args(args) + assert parsed_args.model_name == "something.something" + assert parsed_args.hf_overrides == { + "key1": "val1", + "key2": { + "key3": "val2", + "key4": "val3", + }, + "key5": "val4", + } + + # yapf: enable @pytest.mark.parametrize( "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e942b0f0ff9..6fdb5e6c3772 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -183,7 +183,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = "\n\nShould be a valid JSON string." + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" if dataclass_cls is not None: dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) # Special case for configs with a from_cli method diff --git a/vllm/utils.py b/vllm/utils.py index edfbb8c9481e..0cd90c130d3e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -15,6 +15,7 @@ import importlib.metadata import importlib.util import inspect import ipaddress +import json import multiprocessing import os import pickle @@ -1419,6 +1420,51 @@ class FlexibleArgumentParser(ArgumentParser): else: processed_args.append(arg) + def create_nested_dict(keys: list[str], value: str): + """Creates a nested dictionary from a list of keys and a value. + + For example, `keys = ["a", "b", "c"]` and `value = 1` will create: + `{"a": {"b": {"c": 1}}}` + """ + nested_dict: Any = value + for key in reversed(keys): + nested_dict = {key: nested_dict} + return nested_dict + + def recursive_dict_update(original: dict, update: dict): + """Recursively updates a dictionary with another dictionary.""" + for k, v in update.items(): + if isinstance(v, dict) and isinstance(original.get(k), dict): + recursive_dict_update(original[k], v) + else: + original[k] = v + + delete = set() + dict_args: dict[str, dict] = defaultdict(dict) + for i, processed_arg in enumerate(processed_args): + if processed_arg.startswith("--") and "." in processed_arg: + if "=" in processed_arg: + processed_arg, value = processed_arg.split("=", 1) + if "." not in processed_arg: + # False positive, . was only in the value + continue + else: + value = processed_args[i + 1] + delete.add(i + 1) + key, *keys = processed_arg.split(".") + # Merge all values with the same key into a single dict + arg_dict = create_nested_dict(keys, value) + recursive_dict_update(dict_args[key], arg_dict) + delete.add(i) + # Filter out the dict args we set to None + processed_args = [ + a for i, a in enumerate(processed_args) if i not in delete + ] + # Add the dict args back as if they were originally passed as JSON + for dict_arg, dict_value in dict_args.items(): + processed_args.append(dict_arg) + processed_args.append(json.dumps(dict_value)) + return super().parse_args(processed_args, namespace) def check_port(self, value):