mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
Allow users to pass arbitrary JSON keys from CLI (#18208)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
f4937a51c1
commit
b18201fe06
@ -181,8 +181,8 @@ def test_get_kwargs():
|
|||||||
# literals of literals should have merged choices
|
# literals of literals should have merged choices
|
||||||
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
||||||
# dict should have json tip in help
|
# dict should have json tip in help
|
||||||
json_tip = "\n\nShould be a valid JSON string."
|
json_tip = "Should either be a valid JSON string or JSON keys"
|
||||||
assert kwargs["json_tip"]["help"].endswith(json_tip)
|
assert json_tip in kwargs["json_tip"]["help"]
|
||||||
# nested config should should construct the nested config
|
# nested config should should construct the nested config
|
||||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||||
# from_cli configs should be constructed with the correct method
|
# from_cli configs should be constructed with the correct method
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
import socket
|
import socket
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
@ -138,6 +139,7 @@ def parser():
|
|||||||
parser.add_argument('--model-name')
|
parser.add_argument('--model-name')
|
||||||
parser.add_argument('--batch-size', type=int)
|
parser.add_argument('--batch-size', type=int)
|
||||||
parser.add_argument('--enable-feature', action='store_true')
|
parser.add_argument('--enable-feature', action='store_true')
|
||||||
|
parser.add_argument('--hf-overrides', type=json.loads)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -251,6 +253,29 @@ def test_no_model_tag(parser_with_config, cli_config_file):
|
|||||||
parser_with_config.parse_args(['serve', '--config', cli_config_file])
|
parser_with_config.parse_args(['serve', '--config', cli_config_file])
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_args(parser):
|
||||||
|
args = [
|
||||||
|
"--model-name=something.something",
|
||||||
|
"--hf-overrides.key1",
|
||||||
|
"val1",
|
||||||
|
"--hf-overrides.key2.key3",
|
||||||
|
"val2",
|
||||||
|
"--hf-overrides.key2.key4",
|
||||||
|
"val3",
|
||||||
|
"--hf-overrides.key5=val4",
|
||||||
|
]
|
||||||
|
parsed_args = parser.parse_args(args)
|
||||||
|
assert parsed_args.model_name == "something.something"
|
||||||
|
assert parsed_args.hf_overrides == {
|
||||||
|
"key1": "val1",
|
||||||
|
"key2": {
|
||||||
|
"key3": "val2",
|
||||||
|
"key4": "val3",
|
||||||
|
},
|
||||||
|
"key5": "val4",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
|
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
|
||||||
|
|||||||
@ -183,7 +183,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
kwargs[name] = {"default": default, "help": help}
|
kwargs[name] = {"default": default, "help": help}
|
||||||
|
|
||||||
# Set other kwargs based on the type hints
|
# Set other kwargs based on the type hints
|
||||||
json_tip = "\n\nShould be a valid JSON string."
|
json_tip = """\n\nShould either be a valid JSON string or JSON keys
|
||||||
|
passed individually. For example, the following sets of arguments are
|
||||||
|
equivalent:\n\n
|
||||||
|
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
||||||
|
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
|
||||||
if dataclass_cls is not None:
|
if dataclass_cls is not None:
|
||||||
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
||||||
# Special case for configs with a from_cli method
|
# Special case for configs with a from_cli method
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import importlib.metadata
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -1419,6 +1420,51 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
else:
|
else:
|
||||||
processed_args.append(arg)
|
processed_args.append(arg)
|
||||||
|
|
||||||
|
def create_nested_dict(keys: list[str], value: str):
|
||||||
|
"""Creates a nested dictionary from a list of keys and a value.
|
||||||
|
|
||||||
|
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
|
||||||
|
`{"a": {"b": {"c": 1}}}`
|
||||||
|
"""
|
||||||
|
nested_dict: Any = value
|
||||||
|
for key in reversed(keys):
|
||||||
|
nested_dict = {key: nested_dict}
|
||||||
|
return nested_dict
|
||||||
|
|
||||||
|
def recursive_dict_update(original: dict, update: dict):
|
||||||
|
"""Recursively updates a dictionary with another dictionary."""
|
||||||
|
for k, v in update.items():
|
||||||
|
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||||
|
recursive_dict_update(original[k], v)
|
||||||
|
else:
|
||||||
|
original[k] = v
|
||||||
|
|
||||||
|
delete = set()
|
||||||
|
dict_args: dict[str, dict] = defaultdict(dict)
|
||||||
|
for i, processed_arg in enumerate(processed_args):
|
||||||
|
if processed_arg.startswith("--") and "." in processed_arg:
|
||||||
|
if "=" in processed_arg:
|
||||||
|
processed_arg, value = processed_arg.split("=", 1)
|
||||||
|
if "." not in processed_arg:
|
||||||
|
# False positive, . was only in the value
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
value = processed_args[i + 1]
|
||||||
|
delete.add(i + 1)
|
||||||
|
key, *keys = processed_arg.split(".")
|
||||||
|
# Merge all values with the same key into a single dict
|
||||||
|
arg_dict = create_nested_dict(keys, value)
|
||||||
|
recursive_dict_update(dict_args[key], arg_dict)
|
||||||
|
delete.add(i)
|
||||||
|
# Filter out the dict args we set to None
|
||||||
|
processed_args = [
|
||||||
|
a for i, a in enumerate(processed_args) if i not in delete
|
||||||
|
]
|
||||||
|
# Add the dict args back as if they were originally passed as JSON
|
||||||
|
for dict_arg, dict_value in dict_args.items():
|
||||||
|
processed_args.append(dict_arg)
|
||||||
|
processed_args.append(json.dumps(dict_value))
|
||||||
|
|
||||||
return super().parse_args(processed_args, namespace)
|
return super().parse_args(processed_args, namespace)
|
||||||
|
|
||||||
def check_port(self, value):
|
def check_port(self, value):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user