mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 15:16:14 +08:00
Support non-string values in JSON keys from CLI (#19471)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
871d6b7c74
commit
a2142f0196
@ -13,32 +13,32 @@ from vllm.model_executor.layers.pooler import PoolingType
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
class TestConfig1:
|
class _TestConfig1:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestConfig2:
|
class _TestConfig2:
|
||||||
a: int
|
a: int
|
||||||
"""docstring"""
|
"""docstring"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestConfig3:
|
class _TestConfig3:
|
||||||
a: int = 1
|
a: int = 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestConfig4:
|
class _TestConfig4:
|
||||||
a: Union[Literal[1], Literal[2]] = 1
|
a: Union[Literal[1], Literal[2]] = 1
|
||||||
"""docstring"""
|
"""docstring"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("test_config", "expected_error"), [
|
@pytest.mark.parametrize(("test_config", "expected_error"), [
|
||||||
(TestConfig1, "must be a dataclass"),
|
(_TestConfig1, "must be a dataclass"),
|
||||||
(TestConfig2, "must have a default"),
|
(_TestConfig2, "must have a default"),
|
||||||
(TestConfig3, "must have a docstring"),
|
(_TestConfig3, "must have a docstring"),
|
||||||
(TestConfig4, "must use a single Literal"),
|
(_TestConfig4, "must use a single Literal"),
|
||||||
])
|
])
|
||||||
def test_config(test_config, expected_error):
|
def test_config(test_config, expected_error):
|
||||||
with pytest.raises(Exception, match=expected_error):
|
with pytest.raises(Exception, match=expected_error):
|
||||||
@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
|
|||||||
assert 'inductor_passes' in val
|
assert 'inductor_passes' in val
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _TestConfigFields:
|
||||||
|
a: int
|
||||||
|
b: dict = field(default_factory=dict)
|
||||||
|
c: str = "default"
|
||||||
|
|
||||||
|
|
||||||
def test_get_field():
|
def test_get_field():
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestConfig:
|
|
||||||
a: int
|
|
||||||
b: dict = field(default_factory=dict)
|
|
||||||
c: str = "default"
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
get_field(TestConfig, "a")
|
get_field(_TestConfigFields, "a")
|
||||||
|
|
||||||
b = get_field(TestConfig, "b")
|
b = get_field(_TestConfigFields, "b")
|
||||||
assert isinstance(b, Field)
|
assert isinstance(b, Field)
|
||||||
assert b.default is MISSING
|
assert b.default is MISSING
|
||||||
assert b.default_factory is dict
|
assert b.default_factory is dict
|
||||||
|
|
||||||
c = get_field(TestConfig, "c")
|
c = get_field(_TestConfigFields, "c")
|
||||||
assert isinstance(c, Field)
|
assert isinstance(c, Field)
|
||||||
assert c.default == "default"
|
assert c.default == "default"
|
||||||
assert c.default_factory is MISSING
|
assert c.default_factory is MISSING
|
||||||
|
|||||||
@ -272,6 +272,15 @@ def test_dict_args(parser):
|
|||||||
"val5",
|
"val5",
|
||||||
"--hf_overrides.key-7.key_8",
|
"--hf_overrides.key-7.key_8",
|
||||||
"val6",
|
"val6",
|
||||||
|
# Test data type detection
|
||||||
|
"--hf_overrides.key9",
|
||||||
|
"100",
|
||||||
|
"--hf_overrides.key10",
|
||||||
|
"100.0",
|
||||||
|
"--hf_overrides.key11",
|
||||||
|
"true",
|
||||||
|
"--hf_overrides.key12.key13",
|
||||||
|
"null",
|
||||||
]
|
]
|
||||||
parsed_args = parser.parse_args(args)
|
parsed_args = parser.parse_args(args)
|
||||||
assert parsed_args.model_name == "something.something"
|
assert parsed_args.model_name == "something.something"
|
||||||
@ -286,6 +295,12 @@ def test_dict_args(parser):
|
|||||||
"key-7": {
|
"key-7": {
|
||||||
"key_8": "val6",
|
"key_8": "val6",
|
||||||
},
|
},
|
||||||
|
"key9": 100,
|
||||||
|
"key10": 100.0,
|
||||||
|
"key11": True,
|
||||||
|
"key12": {
|
||||||
|
"key13": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1466,7 +1466,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
pattern = re.compile(r"(?<=--)[^\.]*")
|
pattern = re.compile(r"(?<=--)[^\.]*")
|
||||||
|
|
||||||
# Convert underscores to dashes and vice versa in argument names
|
# Convert underscores to dashes and vice versa in argument names
|
||||||
processed_args = []
|
processed_args = list[str]()
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if arg.startswith('--'):
|
if arg.startswith('--'):
|
||||||
if '=' in arg:
|
if '=' in arg:
|
||||||
@ -1483,7 +1483,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
else:
|
else:
|
||||||
processed_args.append(arg)
|
processed_args.append(arg)
|
||||||
|
|
||||||
def create_nested_dict(keys: list[str], value: str):
|
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
|
||||||
"""Creates a nested dictionary from a list of keys and a value.
|
"""Creates a nested dictionary from a list of keys and a value.
|
||||||
|
|
||||||
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
|
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
|
||||||
@ -1494,7 +1494,10 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
nested_dict = {key: nested_dict}
|
nested_dict = {key: nested_dict}
|
||||||
return nested_dict
|
return nested_dict
|
||||||
|
|
||||||
def recursive_dict_update(original: dict, update: dict):
|
def recursive_dict_update(
|
||||||
|
original: dict[str, Any],
|
||||||
|
update: dict[str, Any],
|
||||||
|
):
|
||||||
"""Recursively updates a dictionary with another dictionary."""
|
"""Recursively updates a dictionary with another dictionary."""
|
||||||
for k, v in update.items():
|
for k, v in update.items():
|
||||||
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||||
@ -1502,19 +1505,25 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
else:
|
else:
|
||||||
original[k] = v
|
original[k] = v
|
||||||
|
|
||||||
delete = set()
|
delete = set[int]()
|
||||||
dict_args: dict[str, dict] = defaultdict(dict)
|
dict_args = defaultdict[str, dict[str, Any]](dict)
|
||||||
for i, processed_arg in enumerate(processed_args):
|
for i, processed_arg in enumerate(processed_args):
|
||||||
if processed_arg.startswith("--") and "." in processed_arg:
|
if processed_arg.startswith("--") and "." in processed_arg:
|
||||||
if "=" in processed_arg:
|
if "=" in processed_arg:
|
||||||
processed_arg, value = processed_arg.split("=", 1)
|
processed_arg, value_str = processed_arg.split("=", 1)
|
||||||
if "." not in processed_arg:
|
if "." not in processed_arg:
|
||||||
# False positive, . was only in the value
|
# False positive, . was only in the value
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
value = processed_args[i + 1]
|
value_str = processed_args[i + 1]
|
||||||
delete.add(i + 1)
|
delete.add(i + 1)
|
||||||
|
|
||||||
key, *keys = processed_arg.split(".")
|
key, *keys = processed_arg.split(".")
|
||||||
|
try:
|
||||||
|
value = json.loads(value_str)
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
value = value_str
|
||||||
|
|
||||||
# Merge all values with the same key into a single dict
|
# Merge all values with the same key into a single dict
|
||||||
arg_dict = create_nested_dict(keys, value)
|
arg_dict = create_nested_dict(keys, value)
|
||||||
recursive_dict_update(dict_args[key], arg_dict)
|
recursive_dict_update(dict_args[key], arg_dict)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user