[8/N] enable cli flag without a space (#10529)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-21 12:30:42 -08:00 committed by GitHub
parent e7a8341c7c
commit 7560ae5caf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 7 deletions

View File

@ -103,7 +103,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + ["-O", str(level)])
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
# inductor will change the output, so we only compare if the output
@ -121,7 +121,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]:
all_args.append(final_args + ["-O", str(level)])
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
# "DYNAMO_ONCE" will always use fullgraph

View File

@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected
def test_compilation_config():
    """Check each accepted CLI spelling of the compilation config option.

    With no arguments the parsed ``compilation_config`` must be ``None``.
    Every other spelling exercised below must yield a config whose
    ``level`` attribute equals 3.
    """
    cli = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # No flag given: the default is "no compilation config".
    assert cli.parse_args([]).compilation_config is None

    # All of these argument vectors must resolve to level 3. They are parsed
    # in this order with the same parser instance.
    level_three_argvs = (
        ["-O3"],  # short flag, value attached directly
        ["-O", "3"],  # short flag, value as a separate token
        ["-O=3"],  # short flag, value attached with '='
        ["--compilation-config", '{"level": 3}'],  # long flag, JSON token
        ['--compilation-config={"level": 3}'],  # long flag, '=' plus JSON
    )
    for argv in level_three_argvs:
        parsed = cli.parse_args(argv)
        assert parsed.compilation_config.level == 3
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([

View File

@ -13,9 +13,10 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher():
# Calls compare_two_settings (defined elsewhere) for "google/gemma-2b" with
# two argument lists that differ only in the -O compilation level
# (DYNAMO_ONCE vs. DYNAMO_AS_IS) and with empty env overrides.
# NOTE(review): what compare_two_settings actually compares is not visible
# here -- confirm against its definition.
compare_two_settings(
"google/gemma-2b",
# NOTE(review): the next four lines look like the pre-change form of this
# hunk ("-O" and the level passed as two separate tokens). They are shown
# together with their replacement below, so arg1/arg2 appear twice in this
# rendering; Python rejects repeated keyword arguments, so only one of the
# two variants can exist in the real file.
arg1=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_ONCE)],
arg2=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_AS_IS)],
# Presumably the post-change form: the level is formatted directly into the
# flag token (f"-O{...}"), i.e. the flag is written without a space.
arg1=[
"--enforce-eager",
f"-O{CompilationLevel.DYNAMO_ONCE}",
],
arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"],
# No environment overrides for either setting.
env1={},
env2={})

View File

@ -882,7 +882,10 @@ class EngineArgs:
'testing only. level 3 is the recommended level '
'for production.\n'
'To specify the full compilation config, '
'use a JSON string.')
'use a JSON string.\n'
'Following the convention of traditional '
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')
return parser

View File

@ -1192,6 +1192,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')
processed_args.append(arg[2:])
else:
processed_args.append(arg)