[Feature][Benchmark] Add --link-vars to run only the combinations where linked serve and bench params are equal (#28909)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
rongfu.leng 2025-11-24 18:02:28 +08:00 committed by GitHub
parent ed40d85929
commit 68dfe28eae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -211,6 +211,7 @@ def run_combs(
output_dir: Path,
num_runs: int,
dry_run: bool,
links: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@ -226,6 +227,14 @@ def run_combs(
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
should_run = all(
serve_key in serve_comb
and bench_key in bench_comb
and serve_comb[serve_key] == bench_comb[bench_key]
for serve_key, bench_key in links
)
if not should_run:
continue
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
@ -262,6 +271,7 @@ class SweepServeArgs:
num_runs: int
dry_run: bool
resume: str | None
link_vars: list[tuple[str, str]] | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@ -285,7 +295,7 @@ class SweepServeArgs:
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@ -301,6 +311,7 @@ class SweepServeArgs:
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
link_vars=link_vars,
)
@classmethod
@ -376,8 +387,28 @@ class SweepServeArgs:
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
)
return parser
@staticmethod
def parse_link_vars(s: str) -> list[tuple[str, str]]:
if not s:
return []
pairs = []
for item in s.split(","):
a, b = item.split("=")
pairs.append((a.strip(), b.strip()))
return pairs
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
@ -397,6 +428,7 @@ def run_main(args: SweepServeArgs):
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
links=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(