mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
Allow users to pass arbitrary JSON keys from CLI (#18208)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
f4937a51c1
commit
b18201fe06
@ -181,8 +181,8 @@ def test_get_kwargs():
|
||||
# literals of literals should have merged choices
|
||||
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
||||
# dict should have json tip in help
|
||||
json_tip = "\n\nShould be a valid JSON string."
|
||||
assert kwargs["json_tip"]["help"].endswith(json_tip)
|
||||
json_tip = "Should either be a valid JSON string or JSON keys"
|
||||
assert json_tip in kwargs["json_tip"]["help"]
|
||||
# nested config should should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
# from_cli configs should be constructed with the correct method
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import AsyncIterator
|
||||
@ -138,6 +139,7 @@ def parser():
|
||||
parser.add_argument('--model-name')
|
||||
parser.add_argument('--batch-size', type=int)
|
||||
parser.add_argument('--enable-feature', action='store_true')
|
||||
parser.add_argument('--hf-overrides', type=json.loads)
|
||||
return parser
|
||||
|
||||
|
||||
@ -251,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file):
|
||||
parser_with_config.parse_args(['serve', '--config', cli_config_file])
|
||||
|
||||
|
||||
def test_dict_args(parser):
|
||||
args = [
|
||||
"--model-name=something.something",
|
||||
"--hf-overrides.key1",
|
||||
"val1",
|
||||
"--hf-overrides.key2.key3",
|
||||
"val2",
|
||||
"--hf-overrides.key2.key4",
|
||||
"val3",
|
||||
"--hf-overrides.key5=val4",
|
||||
]
|
||||
parsed_args = parser.parse_args(args)
|
||||
assert parsed_args.model_name == "something.something"
|
||||
assert parsed_args.hf_overrides == {
|
||||
"key1": "val1",
|
||||
"key2": {
|
||||
"key3": "val2",
|
||||
"key4": "val3",
|
||||
},
|
||||
"key5": "val4",
|
||||
}
|
||||
|
||||
|
||||
# yapf: enable
|
||||
@pytest.mark.parametrize(
|
||||
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
|
||||
|
||||
@ -183,7 +183,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
kwargs[name] = {"default": default, "help": help}
|
||||
|
||||
# Set other kwargs based on the type hints
|
||||
json_tip = "\n\nShould be a valid JSON string."
|
||||
json_tip = """\n\nShould either be a valid JSON string or JSON keys
|
||||
passed individually. For example, the following sets of arguments are
|
||||
equivalent:\n\n
|
||||
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
||||
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
|
||||
if dataclass_cls is not None:
|
||||
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
||||
# Special case for configs with a from_cli method
|
||||
|
||||
@ -15,6 +15,7 @@ import importlib.metadata
|
||||
import importlib.util
|
||||
import inspect
|
||||
import ipaddress
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
@ -1419,6 +1420,51 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
else:
|
||||
processed_args.append(arg)
|
||||
|
||||
def create_nested_dict(keys: list[str], value: str):
|
||||
"""Creates a nested dictionary from a list of keys and a value.
|
||||
|
||||
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
|
||||
`{"a": {"b": {"c": 1}}}`
|
||||
"""
|
||||
nested_dict: Any = value
|
||||
for key in reversed(keys):
|
||||
nested_dict = {key: nested_dict}
|
||||
return nested_dict
|
||||
|
||||
def recursive_dict_update(original: dict, update: dict):
|
||||
"""Recursively updates a dictionary with another dictionary."""
|
||||
for k, v in update.items():
|
||||
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||
recursive_dict_update(original[k], v)
|
||||
else:
|
||||
original[k] = v
|
||||
|
||||
delete = set()
|
||||
dict_args: dict[str, dict] = defaultdict(dict)
|
||||
for i, processed_arg in enumerate(processed_args):
|
||||
if processed_arg.startswith("--") and "." in processed_arg:
|
||||
if "=" in processed_arg:
|
||||
processed_arg, value = processed_arg.split("=", 1)
|
||||
if "." not in processed_arg:
|
||||
# False positive, . was only in the value
|
||||
continue
|
||||
else:
|
||||
value = processed_args[i + 1]
|
||||
delete.add(i + 1)
|
||||
key, *keys = processed_arg.split(".")
|
||||
# Merge all values with the same key into a single dict
|
||||
arg_dict = create_nested_dict(keys, value)
|
||||
recursive_dict_update(dict_args[key], arg_dict)
|
||||
delete.add(i)
|
||||
# Filter out the dict args we set to None
|
||||
processed_args = [
|
||||
a for i, a in enumerate(processed_args) if i not in delete
|
||||
]
|
||||
# Add the dict args back as if they were originally passed as JSON
|
||||
for dict_arg, dict_value in dict_args.items():
|
||||
processed_args.append(dict_arg)
|
||||
processed_args.append(json.dumps(dict_value))
|
||||
|
||||
return super().parse_args(processed_args, namespace)
|
||||
|
||||
def check_port(self, value):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user