mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:25:44 +08:00
Support non-string values in JSON keys from CLI (#19471)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
871d6b7c74
commit
a2142f0196
@ -13,32 +13,32 @@ from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TestConfig1:
|
||||
class _TestConfig1:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig2:
|
||||
class _TestConfig2:
|
||||
a: int
|
||||
"""docstring"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig3:
|
||||
class _TestConfig3:
|
||||
a: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig4:
|
||||
class _TestConfig4:
|
||||
a: Union[Literal[1], Literal[2]] = 1
|
||||
"""docstring"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("test_config", "expected_error"), [
|
||||
(TestConfig1, "must be a dataclass"),
|
||||
(TestConfig2, "must have a default"),
|
||||
(TestConfig3, "must have a docstring"),
|
||||
(TestConfig4, "must use a single Literal"),
|
||||
(_TestConfig1, "must be a dataclass"),
|
||||
(_TestConfig2, "must have a default"),
|
||||
(_TestConfig3, "must have a docstring"),
|
||||
(_TestConfig4, "must use a single Literal"),
|
||||
])
|
||||
def test_config(test_config, expected_error):
|
||||
with pytest.raises(Exception, match=expected_error):
|
||||
@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
|
||||
assert 'inductor_passes' in val
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TestConfigFields:
|
||||
a: int
|
||||
b: dict = field(default_factory=dict)
|
||||
c: str = "default"
|
||||
|
||||
|
||||
def test_get_field():
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
a: int
|
||||
b: dict = field(default_factory=dict)
|
||||
c: str = "default"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_field(TestConfig, "a")
|
||||
get_field(_TestConfigFields, "a")
|
||||
|
||||
b = get_field(TestConfig, "b")
|
||||
b = get_field(_TestConfigFields, "b")
|
||||
assert isinstance(b, Field)
|
||||
assert b.default is MISSING
|
||||
assert b.default_factory is dict
|
||||
|
||||
c = get_field(TestConfig, "c")
|
||||
c = get_field(_TestConfigFields, "c")
|
||||
assert isinstance(c, Field)
|
||||
assert c.default == "default"
|
||||
assert c.default_factory is MISSING
|
||||
|
||||
@ -272,6 +272,15 @@ def test_dict_args(parser):
|
||||
"val5",
|
||||
"--hf_overrides.key-7.key_8",
|
||||
"val6",
|
||||
# Test data type detection
|
||||
"--hf_overrides.key9",
|
||||
"100",
|
||||
"--hf_overrides.key10",
|
||||
"100.0",
|
||||
"--hf_overrides.key11",
|
||||
"true",
|
||||
"--hf_overrides.key12.key13",
|
||||
"null",
|
||||
]
|
||||
parsed_args = parser.parse_args(args)
|
||||
assert parsed_args.model_name == "something.something"
|
||||
@ -286,6 +295,12 @@ def test_dict_args(parser):
|
||||
"key-7": {
|
||||
"key_8": "val6",
|
||||
},
|
||||
"key9": 100,
|
||||
"key10": 100.0,
|
||||
"key11": True,
|
||||
"key12": {
|
||||
"key13": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1466,7 +1466,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
pattern = re.compile(r"(?<=--)[^\.]*")
|
||||
|
||||
# Convert underscores to dashes and vice versa in argument names
|
||||
processed_args = []
|
||||
processed_args = list[str]()
|
||||
for arg in args:
|
||||
if arg.startswith('--'):
|
||||
if '=' in arg:
|
||||
@ -1483,7 +1483,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
else:
|
||||
processed_args.append(arg)
|
||||
|
||||
def create_nested_dict(keys: list[str], value: str):
|
||||
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
|
||||
"""Creates a nested dictionary from a list of keys and a value.
|
||||
|
||||
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
|
||||
@ -1494,7 +1494,10 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
nested_dict = {key: nested_dict}
|
||||
return nested_dict
|
||||
|
||||
def recursive_dict_update(original: dict, update: dict):
|
||||
def recursive_dict_update(
|
||||
original: dict[str, Any],
|
||||
update: dict[str, Any],
|
||||
):
|
||||
"""Recursively updates a dictionary with another dictionary."""
|
||||
for k, v in update.items():
|
||||
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||
@ -1502,19 +1505,25 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
else:
|
||||
original[k] = v
|
||||
|
||||
delete = set()
|
||||
dict_args: dict[str, dict] = defaultdict(dict)
|
||||
delete = set[int]()
|
||||
dict_args = defaultdict[str, dict[str, Any]](dict)
|
||||
for i, processed_arg in enumerate(processed_args):
|
||||
if processed_arg.startswith("--") and "." in processed_arg:
|
||||
if "=" in processed_arg:
|
||||
processed_arg, value = processed_arg.split("=", 1)
|
||||
processed_arg, value_str = processed_arg.split("=", 1)
|
||||
if "." not in processed_arg:
|
||||
# False positive, . was only in the value
|
||||
continue
|
||||
else:
|
||||
value = processed_args[i + 1]
|
||||
value_str = processed_args[i + 1]
|
||||
delete.add(i + 1)
|
||||
|
||||
key, *keys = processed_arg.split(".")
|
||||
try:
|
||||
value = json.loads(value_str)
|
||||
except json.decoder.JSONDecodeError:
|
||||
value = value_str
|
||||
|
||||
# Merge all values with the same key into a single dict
|
||||
arg_dict = create_nested_dict(keys, value)
|
||||
recursive_dict_update(dict_args[key], arg_dict)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user