mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 13:51:18 +08:00
[CLI] Improve CLI arg parsing for -O/--compilation-config (#20156)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
parent
ded1fb635b
commit
6d42ce8315
@ -239,32 +239,40 @@ def test_compilation_config():
|
|||||||
assert args.compilation_config == CompilationConfig()
|
assert args.compilation_config == CompilationConfig()
|
||||||
|
|
||||||
# set to O3
|
# set to O3
|
||||||
args = parser.parse_args(["-O3"])
|
args = parser.parse_args(["-O0"])
|
||||||
assert args.compilation_config.level == 3
|
assert args.compilation_config.level == 0
|
||||||
|
|
||||||
# set to O 3 (space)
|
# set to O 3 (space)
|
||||||
args = parser.parse_args(["-O", "3"])
|
args = parser.parse_args(["-O", "1"])
|
||||||
assert args.compilation_config.level == 3
|
assert args.compilation_config.level == 1
|
||||||
|
|
||||||
# set to O 3 (equals)
|
# set to O 3 (equals)
|
||||||
args = parser.parse_args(["-O=3"])
|
args = parser.parse_args(["-O=2"])
|
||||||
|
assert args.compilation_config.level == 2
|
||||||
|
|
||||||
|
# set to O.level 3
|
||||||
|
args = parser.parse_args(["-O.level", "3"])
|
||||||
assert args.compilation_config.level == 3
|
assert args.compilation_config.level == 3
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args([
|
||||||
"--compilation-config",
|
"-O",
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
|
'"use_inductor": false}',
|
||||||
])
|
])
|
||||||
assert (args.compilation_config.level == 3 and
|
assert (args.compilation_config.level == 3 and
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and not args.compilation_config.use_inductor)
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args([
|
||||||
"--compilation-config="
|
"--compilation-config="
|
||||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
|
'"use_inductor": true}',
|
||||||
])
|
])
|
||||||
assert (args.compilation_config.level == 3 and
|
assert (args.compilation_config.level == 3 and
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
|
and args.compilation_config.use_inductor)
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_default():
|
def test_prefix_cache_default():
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import socket
|
import socket
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
@ -142,6 +143,7 @@ def parser():
|
|||||||
parser.add_argument('--batch-size', type=int)
|
parser.add_argument('--batch-size', type=int)
|
||||||
parser.add_argument('--enable-feature', action='store_true')
|
parser.add_argument('--enable-feature', action='store_true')
|
||||||
parser.add_argument('--hf-overrides', type=json.loads)
|
parser.add_argument('--hf-overrides', type=json.loads)
|
||||||
|
parser.add_argument('-O', '--compilation-config', type=json.loads)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -265,6 +267,11 @@ def test_dict_args(parser):
|
|||||||
"val2",
|
"val2",
|
||||||
"--hf-overrides.key2.key4",
|
"--hf-overrides.key2.key4",
|
||||||
"val3",
|
"val3",
|
||||||
|
# Test compile config and compilation level
|
||||||
|
"-O.use_inductor=true",
|
||||||
|
"-O.backend",
|
||||||
|
"custom",
|
||||||
|
"-O1",
|
||||||
# Test = sign
|
# Test = sign
|
||||||
"--hf-overrides.key5=val4",
|
"--hf-overrides.key5=val4",
|
||||||
# Test underscore to dash conversion
|
# Test underscore to dash conversion
|
||||||
@ -281,6 +288,13 @@ def test_dict_args(parser):
|
|||||||
"true",
|
"true",
|
||||||
"--hf_overrides.key12.key13",
|
"--hf_overrides.key12.key13",
|
||||||
"null",
|
"null",
|
||||||
|
# Test '-' and '.' in value
|
||||||
|
"--hf_overrides.key14.key15",
|
||||||
|
"-minus.and.dot",
|
||||||
|
# Test array values
|
||||||
|
"-O.custom_ops+",
|
||||||
|
"-quant_fp8",
|
||||||
|
"-O.custom_ops+=+silu_mul,-rms_norm",
|
||||||
]
|
]
|
||||||
parsed_args = parser.parse_args(args)
|
parsed_args = parser.parse_args(args)
|
||||||
assert parsed_args.model_name == "something.something"
|
assert parsed_args.model_name == "something.something"
|
||||||
@ -301,7 +315,40 @@ def test_dict_args(parser):
|
|||||||
"key12": {
|
"key12": {
|
||||||
"key13": None,
|
"key13": None,
|
||||||
},
|
},
|
||||||
|
"key14": {
|
||||||
|
"key15": "-minus.and.dot",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
assert parsed_args.compilation_config == {
|
||||||
|
"level": 1,
|
||||||
|
"use_inductor": True,
|
||||||
|
"backend": "custom",
|
||||||
|
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_dict_args(caplog_vllm, parser):
|
||||||
|
args = [
|
||||||
|
"--model-name=something.something",
|
||||||
|
"--hf-overrides.key1",
|
||||||
|
"val1",
|
||||||
|
"--hf-overrides.key1",
|
||||||
|
"val2",
|
||||||
|
"-O1",
|
||||||
|
"-O.level",
|
||||||
|
"2",
|
||||||
|
"-O3",
|
||||||
|
]
|
||||||
|
|
||||||
|
parsed_args = parser.parse_args(args)
|
||||||
|
# Should be the last value
|
||||||
|
assert parsed_args.hf_overrides == {"key1": "val2"}
|
||||||
|
assert parsed_args.compilation_config == {"level": 3}
|
||||||
|
|
||||||
|
assert len(caplog_vllm.records) == 1
|
||||||
|
assert "duplicate" in caplog_vllm.text
|
||||||
|
assert "--hf-overrides.key1" in caplog_vllm.text
|
||||||
|
assert "-O.level" in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|||||||
@ -4140,9 +4140,9 @@ class CompilationConfig:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
||||||
"""Parse the CLI value for the compilation config."""
|
"""Parse the CLI value for the compilation config.
|
||||||
if cli_value in ["0", "1", "2", "3"]:
|
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
|
||||||
return cls(level=int(cli_value))
|
"""
|
||||||
return TypeAdapter(CompilationConfig).validate_json(cli_value)
|
return TypeAdapter(CompilationConfig).validate_json(cli_value)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
@ -4303,17 +4303,16 @@ class VllmConfig:
|
|||||||
"""Quantization configuration."""
|
"""Quantization configuration."""
|
||||||
compilation_config: CompilationConfig = field(
|
compilation_config: CompilationConfig = field(
|
||||||
default_factory=CompilationConfig)
|
default_factory=CompilationConfig)
|
||||||
"""`torch.compile` configuration for the model.
|
"""`torch.compile` and cudagraph capture configuration for the model.
|
||||||
|
|
||||||
When it is a number (0, 1, 2, 3), it will be interpreted as the
|
As a shorthand, `-O<n>` can be used to directly specify the compilation
|
||||||
optimization level.
|
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
|
||||||
|
Currently, -O <n> and -O=<n> are supported as well but this will likely be
|
||||||
|
removed in favor of clearer -O<n> syntax in the future.
|
||||||
|
|
||||||
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
||||||
are for internal testing only. level 3 is the recommended level for
|
are for internal testing only. level 3 is the recommended level for
|
||||||
production.
|
production, also default in V1.
|
||||||
|
|
||||||
Following the convention of traditional compilers, using `-O` without space
|
|
||||||
is also supported. `-O3` is equivalent to `-O 3`.
|
|
||||||
|
|
||||||
You can specify the full compilation config like so:
|
You can specify the full compilation config like so:
|
||||||
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||||
|
|||||||
@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
passed individually. For example, the following sets of arguments are
|
passed individually. For example, the following sets of arguments are
|
||||||
equivalent:\n\n
|
equivalent:\n\n
|
||||||
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
||||||
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
|
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
|
||||||
|
Additionally, list elements can be passed individually using '+':
|
||||||
|
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
|
||||||
|
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
|
||||||
if dataclass_cls is not None:
|
if dataclass_cls is not None:
|
||||||
|
|
||||||
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
|
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
|
||||||
|
|||||||
@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
|||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_SWA = \
|
STR_NOT_IMPL_ENC_DEC_SWA = \
|
||||||
"Sliding window attention for encoder/decoder models " + \
|
"Sliding window attention for encoder/decoder models " + \
|
||||||
"is not currently supported."
|
"is not currently supported."
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
||||||
"Prefix caching for encoder/decoder models " + \
|
"Prefix caching for encoder/decoder models " + \
|
||||||
"is not currently supported."
|
"is not currently supported."
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
|
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
|
||||||
"Chunked prefill for encoder/decoder models " + \
|
"Chunked prefill for encoder/decoder models " + \
|
||||||
"is not currently supported."
|
"is not currently supported."
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
|
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
|
||||||
"Models with logits_soft_cap "
|
"Models with logits_soft_cap "
|
||||||
@ -752,7 +752,7 @@ def _generate_random_fp8(
|
|||||||
# to generate random data for fp8 data.
|
# to generate random data for fp8 data.
|
||||||
# For example, s.11111.00 in fp8e5m2 format represents Inf.
|
# For example, s.11111.00 in fp8e5m2 format represents Inf.
|
||||||
# | E4M3 | E5M2
|
# | E4M3 | E5M2
|
||||||
#-----|-------------|-------------------
|
# -----|-------------|-------------------
|
||||||
# Inf | N/A | s.11111.00
|
# Inf | N/A | s.11111.00
|
||||||
# NaN | s.1111.111 | s.11111.{01,10,11}
|
# NaN | s.1111.111 | s.11111.{01,10,11}
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@ -840,7 +840,6 @@ def create_kv_caches_with_random(
|
|||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
device: Optional[str] = "cuda",
|
device: Optional[str] = "cuda",
|
||||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||||
|
|
||||||
if cache_dtype == "fp8" and head_size % 16:
|
if cache_dtype == "fp8" and head_size % 16:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Does not support key cache of type fp8 with head_size {head_size}"
|
f"Does not support key cache of type fp8 with head_size {head_size}"
|
||||||
@ -1205,7 +1204,6 @@ def deprecate_args(
|
|||||||
is_deprecated: Union[bool, Callable[[], bool]] = True,
|
is_deprecated: Union[bool, Callable[[], bool]] = True,
|
||||||
additional_message: Optional[str] = None,
|
additional_message: Optional[str] = None,
|
||||||
) -> Callable[[F], F]:
|
) -> Callable[[F], F]:
|
||||||
|
|
||||||
if not callable(is_deprecated):
|
if not callable(is_deprecated):
|
||||||
is_deprecated = partial(identity, is_deprecated)
|
is_deprecated = partial(identity, is_deprecated)
|
||||||
|
|
||||||
@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
|
|||||||
return weak_bound
|
return weak_bound
|
||||||
|
|
||||||
|
|
||||||
#From: https://stackoverflow.com/a/4104188/2749989
|
# From: https://stackoverflow.com/a/4104188/2749989
|
||||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||||
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
# Convert underscores to dashes and vice versa in argument names
|
# Convert underscores to dashes and vice versa in argument names
|
||||||
processed_args = list[str]()
|
processed_args = list[str]()
|
||||||
for arg in args:
|
for i, arg in enumerate(args):
|
||||||
if arg.startswith('--'):
|
if arg.startswith('--'):
|
||||||
if '=' in arg:
|
if '=' in arg:
|
||||||
key, value = arg.split('=', 1)
|
key, value = arg.split('=', 1)
|
||||||
@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
else:
|
else:
|
||||||
key = pattern.sub(repl, arg, count=1)
|
key = pattern.sub(repl, arg, count=1)
|
||||||
processed_args.append(key)
|
processed_args.append(key)
|
||||||
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
|
elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
|
||||||
# allow -O flag to be used without space, e.g. -O3
|
# allow -O flag to be used without space, e.g. -O3 or -Odecode
|
||||||
processed_args.append('-O')
|
# -O.<...> handled later
|
||||||
processed_args.append(arg[2:])
|
# also handle -O=<level> here
|
||||||
|
level = arg[3:] if arg[2] == '=' else arg[2:]
|
||||||
|
processed_args.append(f'-O.level={level}')
|
||||||
|
elif arg == '-O' and i + 1 < len(args) and args[i + 1] in {
|
||||||
|
"0", "1", "2", "3"
|
||||||
|
}:
|
||||||
|
# Convert -O <n> to -O.level <n>
|
||||||
|
processed_args.append('-O.level')
|
||||||
else:
|
else:
|
||||||
processed_args.append(arg)
|
processed_args.append(arg)
|
||||||
|
|
||||||
@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
def recursive_dict_update(
|
def recursive_dict_update(
|
||||||
original: dict[str, Any],
|
original: dict[str, Any],
|
||||||
update: dict[str, Any],
|
update: dict[str, Any],
|
||||||
):
|
) -> set[str]:
|
||||||
"""Recursively updates a dictionary with another dictionary."""
|
"""Recursively updates a dictionary with another dictionary.
|
||||||
|
Returns a set of duplicate keys that were overwritten.
|
||||||
|
"""
|
||||||
|
duplicates = set[str]()
|
||||||
for k, v in update.items():
|
for k, v in update.items():
|
||||||
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||||
recursive_dict_update(original[k], v)
|
nested_duplicates = recursive_dict_update(original[k], v)
|
||||||
|
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
|
||||||
|
elif isinstance(v, list) and isinstance(original.get(k), list):
|
||||||
|
original[k] += v
|
||||||
else:
|
else:
|
||||||
|
if k in original:
|
||||||
|
duplicates.add(k)
|
||||||
original[k] = v
|
original[k] = v
|
||||||
|
return duplicates
|
||||||
|
|
||||||
delete = set[int]()
|
delete = set[int]()
|
||||||
dict_args = defaultdict[str, dict[str, Any]](dict)
|
dict_args = defaultdict[str, dict[str, Any]](dict)
|
||||||
|
duplicates = set[str]()
|
||||||
for i, processed_arg in enumerate(processed_args):
|
for i, processed_arg in enumerate(processed_args):
|
||||||
if processed_arg.startswith("--") and "." in processed_arg:
|
if i in delete: # skip if value from previous arg
|
||||||
|
continue
|
||||||
|
|
||||||
|
if processed_arg.startswith("-") and "." in processed_arg:
|
||||||
if "=" in processed_arg:
|
if "=" in processed_arg:
|
||||||
processed_arg, value_str = processed_arg.split("=", 1)
|
processed_arg, value_str = processed_arg.split("=", 1)
|
||||||
if "." not in processed_arg:
|
if "." not in processed_arg:
|
||||||
# False positive, . was only in the value
|
# False positive, '.' was only in the value
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
value_str = processed_args[i + 1]
|
value_str = processed_args[i + 1]
|
||||||
delete.add(i + 1)
|
delete.add(i + 1)
|
||||||
|
|
||||||
|
if processed_arg.endswith("+"):
|
||||||
|
processed_arg = processed_arg[:-1]
|
||||||
|
value_str = json.dumps(list(value_str.split(",")))
|
||||||
|
|
||||||
key, *keys = processed_arg.split(".")
|
key, *keys = processed_arg.split(".")
|
||||||
try:
|
try:
|
||||||
value = json.loads(value_str)
|
value = json.loads(value_str)
|
||||||
@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
# Merge all values with the same key into a single dict
|
# Merge all values with the same key into a single dict
|
||||||
arg_dict = create_nested_dict(keys, value)
|
arg_dict = create_nested_dict(keys, value)
|
||||||
recursive_dict_update(dict_args[key], arg_dict)
|
arg_duplicates = recursive_dict_update(dict_args[key],
|
||||||
|
arg_dict)
|
||||||
|
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
|
||||||
delete.add(i)
|
delete.add(i)
|
||||||
# Filter out the dict args we set to None
|
# Filter out the dict args we set to None
|
||||||
processed_args = [
|
processed_args = [
|
||||||
a for i, a in enumerate(processed_args) if i not in delete
|
a for i, a in enumerate(processed_args) if i not in delete
|
||||||
]
|
]
|
||||||
|
if duplicates:
|
||||||
|
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
|
||||||
|
|
||||||
# Add the dict args back as if they were originally passed as JSON
|
# Add the dict args back as if they were originally passed as JSON
|
||||||
for dict_arg, dict_value in dict_args.items():
|
for dict_arg, dict_value in dict_args.items():
|
||||||
processed_args.append(dict_arg)
|
processed_args.append(dict_arg)
|
||||||
@ -2405,7 +2432,7 @@ def memory_profiling(
|
|||||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
|
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
|
||||||
|
|
||||||
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
|
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
|
||||||
""" # noqa
|
""" # noqa
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user