Allow users to pass arbitrary JSON keys from CLI (#18208)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-05-16 05:05:34 +01:00 committed by GitHub
parent f4937a51c1
commit b18201fe06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 3 deletions

View File

@ -181,8 +181,8 @@ def test_get_kwargs():
# literals of literals should have merged choices # literals of literals should have merged choices
assert kwargs["literal_literal"]["choices"] == [1, 2] assert kwargs["literal_literal"]["choices"] == [1, 2]
# dict should have json tip in help # dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string." json_tip = "Should either be a valid JSON string or JSON keys"
assert kwargs["json_tip"]["help"].endswith(json_tip) assert json_tip in kwargs["json_tip"]["help"]
# nested config should construct the nested config # nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method # from_cli configs should be constructed with the correct method

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import hashlib import hashlib
import json
import pickle import pickle
import socket import socket
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@ -138,6 +139,7 @@ def parser():
parser.add_argument('--model-name') parser.add_argument('--model-name')
parser.add_argument('--batch-size', type=int) parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads)
return parser return parser
@ -251,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file):
parser_with_config.parse_args(['serve', '--config', cli_config_file]) parser_with_config.parse_args(['serve', '--config', cli_config_file])
def test_dict_args(parser):
    """Dotted `--hf-overrides.*` flags are merged into one nested dict.

    Covers both the space-separated form (`--flag.key value`) and the
    `=` form (`--flag.key=value`), and checks that a `.` appearing only in
    the value of an ordinary flag (`--model-name=something.something`) is
    left untouched.
    """
    cli_args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        "--hf-overrides.key5=val4",
    ]
    # key2.key3 and key2.key4 must end up side by side under a single "key2".
    expected_overrides = {
        "key1": "val1",
        "key2": {
            "key3": "val2",
            "key4": "val3",
        },
        "key5": "val4",
    }

    namespace = parser.parse_args(cli_args)

    assert namespace.model_name == "something.something"
    assert namespace.hf_overrides == expected_overrides
# yapf: enable # yapf: enable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",

View File

@ -183,7 +183,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name] = {"default": default, "help": help} kwargs[name] = {"default": default, "help": help}
# Set other kwargs based on the type hints # Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string." json_tip = """\n\nShould either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:\n\n
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
if dataclass_cls is not None: if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method # Special case for configs with a from_cli method

View File

@ -15,6 +15,7 @@ import importlib.metadata
import importlib.util import importlib.util
import inspect import inspect
import ipaddress import ipaddress
import json
import multiprocessing import multiprocessing
import os import os
import pickle import pickle
@ -1419,6 +1420,51 @@ class FlexibleArgumentParser(ArgumentParser):
else: else:
processed_args.append(arg) processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str):
    """Wrap `value` in one dictionary level per entry of `keys`.

    The first key becomes the outermost level, so
    `keys = ["a", "b", "c"]` and `value = 1` produce
    `{"a": {"b": {"c": 1}}}`. With no keys, `value` is returned as-is.
    """
    if not keys:
        # Nothing left to wrap: this is the innermost value.
        return value
    outermost, *inner_keys = keys
    return {outermost: create_nested_dict(inner_keys, value)}
def recursive_dict_update(original: dict, update: dict):
    """Merge `update` into `original` in place, descending into nested dicts.

    When both sides hold a dict under the same key, the two are merged
    recursively; in every other case the value from `update` replaces
    (or creates) the entry in `original`. Nothing is returned.
    """
    for name, incoming in update.items():
        existing = original.get(name)
        mergeable = isinstance(incoming, dict) and isinstance(existing, dict)
        if not mergeable:
            # Scalar, new key, or type mismatch: the incoming value wins.
            original[name] = incoming
            continue
        recursive_dict_update(existing, incoming)
delete = set()
dict_args: dict[str, dict] = defaultdict(dict)
for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, . was only in the value
continue
else:
value = processed_args[i + 1]
delete.add(i + 1)
key, *keys = processed_arg.split(".")
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict)
delete.add(i)
# Filter out the dict args (and their separately passed values) marked for deletion
processed_args = [
a for i, a in enumerate(processed_args) if i not in delete
]
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))
return super().parse_args(processed_args, namespace) return super().parse_args(processed_args, namespace)
def check_port(self, value): def check_port(self, value):