diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md
index fa56c6e18284a..52a16d7bdbff1 100644
--- a/docs/contributing/benchmarks.md
+++ b/docs/contributing/benchmarks.md
@@ -6,7 +6,8 @@ toc_depth: 4
vLLM provides comprehensive benchmarking tools for performance testing and evaluation:
-- **[Benchmark CLI]**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing
+- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing
+- **[Batch Scripts](#batch-scripts)**: Run `vllm bench` against multiple configurations conveniently
- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development
- **[Nightly benchmarks](#nightly-benchmarks)**: Comparative benchmarks against alternatives
@@ -29,7 +30,7 @@ th {
| Dataset | Online | Offline | Data Path |
|---------|--------|---------|-----------|
| ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` |
-| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`<br>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:<br>`wget http://images.cocodataset.org/zips/train2017.zip` |
+| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`<br>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:<br>`wget http://images.cocodataset.org/zips/train2017.zip` |
| ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` |
| BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` |
| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
@@ -714,7 +715,7 @@ Generate synthetic image inputs alongside random text prompts to stress-test vis
Notes:
-- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
+- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.
Start the server (example):
@@ -924,6 +925,147 @@ throughput numbers correctly is also adjusted.
+## Batch Scripts
+
+### Batch Serving Script
+
+[`vllm/benchmarks/serve_multi.py`](../../vllm/benchmarks/serve_multi.py) automatically starts `vllm serve` and runs `vllm bench serve` over multiple configurations.
+
+#### Batch Mode
+
+The basic purpose of this script is to evaluate vLLM under different settings. Follow these steps to run the script:
+
+1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option.
+2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option.
+3. (Optional) If you would like to vary the settings of `vllm serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--serve-params`.
+
+ - Example: Tuning `--max-num-seqs` and `--max-num-batched-tokens`:
+
+ ```json
+ [
+ {
+ "max_num_seqs": 32,
+ "max_num_batched_tokens": 1024
+ },
+ {
+ "max_num_seqs": 64,
+ "max_num_batched_tokens": 1024
+ },
+ {
+ "max_num_seqs": 64,
+ "max_num_batched_tokens": 2048
+ },
+ {
+ "max_num_seqs": 128,
+ "max_num_batched_tokens": 2048
+ },
+ {
+ "max_num_seqs": 128,
+ "max_num_batched_tokens": 4096
+ },
+ {
+ "max_num_seqs": 256,
+ "max_num_batched_tokens": 4096
+ }
+ ]
+ ```
+
+4. (Optional) If you would like to vary the settings of `vllm bench serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--bench-params`.
+
+ - Example: Using different input/output lengths for random dataset:
+
+ ```json
+ [
+ {
+ "random_input_len": 128,
+ "random_output_len": 32
+ },
+ {
+ "random_input_len": 256,
+ "random_output_len": 64
+ },
+ {
+ "random_input_len": 512,
+ "random_output_len": 128
+ }
+ ]
+ ```
+
+5. Determine where you want to save the results, and pass that to `--output-dir`.
+
+Example command:
+
+```bash
+python vllm/benchmarks/serve_multi.py \
+ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
+ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
+ --serve-params benchmarks/serve_hparams.json \
+ --bench-params benchmarks/bench_hparams.json \
+ -o benchmarks/results
+```
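+
+Results are written to a timestamped subdirectory of `--output-dir`, with one folder per parameter combination. With the example parameter files above, the layout looks roughly like this (sketch only):
+
+```text
+benchmarks/results/<timestamp>/
+├── SERVE-max_num_seqs=32-max_num_batched_tokens=1024-BENCH-random_input_len=128-random_output_len=32/
+│   ├── run=0.json
+│   ├── run=1.json
+│   ├── run=2.json
+│   └── summary.json
+├── ...
+└── summary.csv
+```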
+
+!!! important
+ If both `--serve-params` and `--bench-params` are passed, the script will iterate over the Cartesian product between them.
+ You can use `--dry-run` to preview the commands to be run.
+
+ We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
+ Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run.
+    If you are using a custom `--serve-cmd`, you can override the command used to reset this state by setting `--after-bench-cmd`.
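+
+    For example, assuming the server is running locally on the default port, a custom reset command could look like this (illustrative only):
+
+    ```bash
+    --after-bench-cmd 'curl -X POST http://localhost:8000/reset_prefix_cache'
+    ```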
+
+!!! note
+ By default, each parameter combination is run 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`.
+
+!!! tip
+ You can use the `--resume` option to continue the parameter sweep if one of the runs failed.
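+
+    For example, if an earlier sweep wrote its results under `benchmarks/results/20250101_123000` (the directory name is a timestamp; this one is illustrative), you can continue it by re-running the same command with:
+
+    ```bash
+    --resume 20250101_123000
+    ```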
+
+#### SLA Mode
+
+By passing SLA constraints via `--sla-params`, you can run this script in SLA mode, which adjusts either the request rate or the maximum concurrency (selected via `--sla-variable`) to satisfy the SLA constraints.
+
+For example, to ensure that the E2E latency stays within various target values for 99% of requests:
+
+```json
+[
+ {
+ "p99_e2el_ms": "<=200"
+ },
+ {
+ "p99_e2el_ms": "<=500"
+ },
+ {
+ "p99_e2el_ms": "<=1000"
+ },
+ {
+ "p99_e2el_ms": "<=2000"
+ }
+]
+```
+
+Example command:
+
+```bash
+python vllm/benchmarks/serve_multi.py \
+ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
+ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
+ --serve-params benchmarks/serve_hparams.json \
+ --bench-params benchmarks/bench_hparams.json \
+ --sla-params benchmarks/sla_hparams.json \
+ --sla-variable max_concurrency \
+ -o benchmarks/results
+```
+
+The algorithm for adjusting the SLA variable is as follows:
+
+1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable.
+    - For example, the initial request rate is set to the request throughput measured under infinite QPS, while the initial maximum concurrency is estimated from that throughput and the mean E2E latency.
+2. Starting from that value, keep doubling it until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied.
+3. Apply binary search over the window to find the maximum value that still satisfies the SLA.
+
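+In rough pseudocode, the search looks like this (a minimal sketch with illustrative names, not the script's real API; the actual script additionally persists every run to disk and reuses existing results):
+
+```python
+def find_max_sla_value(run_bench, meets_sla, inf_value: int = 65536) -> int:
+    """Largest value of the SLA variable that still satisfies the SLA.
+
+    `run_bench(value)` runs the benchmark with the SLA variable set to `value`
+    and returns its metrics; `meets_sla(metrics)` checks the SLA constraints.
+    """
+    # 1. Benchmark under effectively infinite QPS and derive a starting point
+    #    (e.g. the measured request throughput).
+    value = max(1, round(run_bench(inf_value)["request_throughput"]))
+
+    # 2. Keep doubling while the SLA holds to bracket the crossover point.
+    lo, hi = 0, inf_value
+    while value < inf_value:
+        if meets_sla(run_bench(value)):
+            lo, value = value, value * 2
+        else:
+            hi = value
+            break
+
+    # 3. Binary search inside the bracket for the largest passing value.
+    while hi - lo > 1:
+        mid = (lo + hi) // 2
+        if meets_sla(run_bench(mid)):
+            lo = mid
+        else:
+            hi = mid
+    return lo
+```
+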
+!!! important
+ SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`.
+
+ For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value.
+
## Performance Benchmarks
The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
diff --git a/vllm/benchmarks/serve_multi.py b/vllm/benchmarks/serve_multi.py
new file mode 100644
index 0000000000000..e8524473aedd5
--- /dev/null
+++ b/vllm/benchmarks/serve_multi.py
@@ -0,0 +1,1157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
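+"""Run `vllm bench serve` against `vllm serve` over multiple parameter combinations.
+
+See the "Batch Scripts" section of docs/contributing/benchmarks.md for usage.
+"""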
+import argparse
+import contextlib
+import json
+import math
+import os
+import shlex
+import signal
+import subprocess
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import Literal, get_args
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import requests
+import seaborn as sns
+from typing_extensions import assert_never, override
+
+_BAD_PARAMS_TYPE_MSG = (
+ "The parameters to vary should be expressed as a JSON list of dictionaries."
+)
+
+
+def _parse_params(params: list[dict[str, object]]):
+ if not isinstance(params, list):
+ raise TypeError(f"{_BAD_PARAMS_TYPE_MSG} Found JSON type {type(params)}")
+
+ for comb in params:
+ if not isinstance(comb, dict):
+ raise TypeError(f"{_BAD_PARAMS_TYPE_MSG} Found item type {type(comb)}")
+
+ return params
+
+
+class SLACriterionBase(ABC):
+ def __init__(self, target: float) -> None:
+ super().__init__()
+
+ self.target = target
+
+ @abstractmethod
+ def validate(self, actual: float) -> bool:
+ """Return `True` if this criterion is met; otherwise `False`."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def format_cond(self, lhs: str) -> str:
+ raise NotImplementedError
+
+ def print_and_validate(
+ self,
+ metrics: dict[str, float],
+ metrics_key: str,
+ ) -> bool:
+ metric = metrics[metrics_key]
+ result = self.validate(metric)
+
+ cond = self.format_cond(f"{metrics_key} = {metric:.2f}")
+ print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED"))
+
+ return result
+
+
+class SLALessThan(SLACriterionBase):
+ @override
+ def validate(self, actual: float) -> bool:
+ return actual < self.target
+
+ @override
+ def format_cond(self, lhs: str) -> str:
+ return f"{lhs}<{self.target:.2f}"
+
+
+class SLALessThanOrEqual(SLACriterionBase):
+ @override
+ def validate(self, actual: float) -> bool:
+ return actual <= self.target
+
+ @override
+ def format_cond(self, lhs: str) -> str:
+ return f"{lhs}<={self.target:.2f}"
+
+
+class SLAGreaterThan(SLACriterionBase):
+ @override
+ def validate(self, actual: float) -> bool:
+ return actual > self.target
+
+ @override
+ def format_cond(self, lhs: str) -> str:
+ return f"{lhs}>{self.target:.2f}"
+
+
+class SLAGreaterThanOrEqual(SLACriterionBase):
+ @override
+ def validate(self, actual: float) -> bool:
+ return actual >= self.target
+
+ @override
+ def format_cond(self, lhs: str) -> str:
+ return f"{lhs}>={self.target:.2f}"
+
+
+# NOTE: The ordering is important! Match longer op_keys first
+SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
+ "<=": SLALessThanOrEqual,
+ ">=": SLAGreaterThanOrEqual,
+ "<": SLALessThan,
+ ">": SLAGreaterThan,
+}
+
+
+def _parse_sla_item(sla_item: dict[str, str]):
+ sla_criteria: dict[str, SLACriterionBase] = {}
+
+ for metric_key, metric_value in sla_item.items():
+ for op_key in SLA_CRITERIA:
+ if metric_value.startswith(op_key):
+ sla_criteria[metric_key] = SLA_CRITERIA[op_key](
+ float(metric_value.removeprefix(op_key))
+ )
+ break
+ else:
+ raise ValueError(
+ f"Invalid operator for SLA constraint '{metric_key}={metric_value}'. "
+ f"Valid operators are: {set(SLA_CRITERIA)}",
+ )
+
+ return sla_criteria
+
+
+def _parse_sla(sla: list[dict[str, str]]):
+ return [_parse_sla_item(item) for item in sla]
+
+
+# In JSON, we prefer "_"
+def _iter_param_key_candidates(param_key: str):
+ yield param_key
+ yield param_key.replace("-", "_")
+ yield param_key.replace("_", "-")
+
+
+# In CLI, we prefer "-"
+def _iter_cmd_key_candidates(param_key: str):
+ for k in reversed(tuple(_iter_param_key_candidates(param_key))):
+ yield "--" + k
+
+
+def _normalize_cmd_key(param_key: str):
+ return next(_iter_cmd_key_candidates(param_key))
+
+
+def _override_args(cmd: list[str], params: dict[str, object]):
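+    """Return a copy of `cmd` with `params` applied as CLI overrides.
+
+    Existing `--key value` entries are replaced in place and missing keys are
+    appended; boolean values are rendered as `--key` / `--no-key` flags.
+    Keys are matched with both `-` and `_` separators.
+    """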
+ cmd = list(cmd)
+
+ for k, v in params.items():
+ for k_candidate in _iter_cmd_key_candidates(k):
+ try:
+ k_idx = cmd.index(k_candidate)
+
+ if isinstance(v, bool):
+ cmd[k_idx] = _normalize_cmd_key(k if v else "no-" + k)
+ else:
+ cmd[k_idx + 1] = str(v)
+
+ break
+ except ValueError:
+ continue
+ else:
+ if isinstance(v, bool):
+ cmd.append(_normalize_cmd_key(k if v else "no-" + k))
+ else:
+ cmd.extend([_normalize_cmd_key(k), str(v)])
+
+ return cmd
+
+
+class ServerWrapper:
+ def __init__(
+ self,
+ server_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ ) -> None:
+ super().__init__()
+
+ self.server_cmd = server_cmd
+ self.after_bench_cmd = after_bench_cmd
+ self.show_stdout = show_stdout
+
+ def run_subcommand(self, cmd: list[str]):
+ return subprocess.run(
+ cmd,
+ stdout=None if self.show_stdout else subprocess.DEVNULL,
+ check=True,
+ )
+
+ def after_bench(self) -> None:
+ if not self.after_bench_cmd:
+ self.reset_caches()
+ return
+
+ self.run_subcommand(self.after_bench_cmd)
+
+ def _get_vllm_server_address(self) -> str:
+ server_cmd = self.server_cmd
+
+ for host_key in ("--host",):
+ if host_key in server_cmd:
+ host = server_cmd[server_cmd.index(host_key) + 1]
+ break
+ else:
+ host = "localhost"
+
+ for port_key in ("-p", "--port"):
+ if port_key in server_cmd:
+ port = int(server_cmd[server_cmd.index(port_key) + 1])
+ break
+ else:
+ port = 8000 # The default value in vllm serve
+
+ return f"http://{host}:{port}"
+
+ def reset_caches(self) -> None:
+ server_cmd = self.server_cmd
+
+ # Use `.endswith()` to match `/bin/...`
+ if server_cmd[0].endswith("vllm"):
+ server_address = self._get_vllm_server_address()
+ print(f"Resetting caches at {server_address}")
+
+ res = requests.post(f"{server_address}/reset_prefix_cache")
+ res.raise_for_status()
+
+ res = requests.post(f"{server_address}/reset_mm_cache")
+ res.raise_for_status()
+ elif server_cmd[0].endswith("infinity_emb"):
+ if "--vector-disk-cache" in server_cmd:
+ raise NotImplementedError(
+ "Infinity server uses caching but does not expose a method "
+ "to reset the cache"
+ )
+ else:
+ raise NotImplementedError(
+ f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
+ "Please specify a custom command via `--after-bench-cmd`."
+ )
+
+
+@contextlib.contextmanager
+def _run_server(
+ serve_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ serve_overrides: dict[str, object],
+ dry_run: bool,
+):
+ server_cmd = _override_args(serve_cmd, serve_overrides)
+
+ print("[BEGIN SERVER]")
+ print(f"Server overrides: {serve_overrides}")
+ print(f"Server command: {server_cmd}")
+
+ if dry_run:
+ yield None
+ print("[END SERVER]")
+ return
+
+ # Create new process group for clean termination
+ server_process = subprocess.Popen(
+ server_cmd,
+ start_new_session=True,
+ stdout=None if show_stdout else subprocess.DEVNULL,
+        # Need VLLM_SERVER_DEV_MODE=1 for `ServerWrapper.reset_caches`
+ env={**os.environ, "VLLM_SERVER_DEV_MODE": "1"},
+ )
+
+ try:
+ yield ServerWrapper(
+ server_cmd,
+ after_bench_cmd,
+ show_stdout=show_stdout,
+ )
+ finally:
+ if server_process.poll() is None:
+ # In case only some processes have been terminated
+ with contextlib.suppress(ProcessLookupError):
+ # We need to kill both API Server and Engine processes
+ os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
+
+ print("[END SERVER]")
+
+
+def _run_benchmark(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_overrides: dict[str, object],
+ bench_overrides: dict[str, object],
+ run_number: int,
+ output_path: Path,
+ dry_run: bool,
+):
+ benchmark_cmd = [
+ *_override_args(bench_cmd, bench_overrides),
+ "--save-result",
+ "--result-dir",
+ str(output_path.parent),
+ "--result-filename",
+ output_path.name,
+ ]
+
+ print("[BEGIN BENCHMARK]")
+ print(f"Benchmark overrides: {bench_overrides}")
+ print(f"Run Number: {run_number}")
+ print(f"Benchmark command: {benchmark_cmd}")
+ print(f"Output file: {output_path}")
+
+ run_data: dict[str, object]
+
+ if output_path.exists():
+ print("Found existing results. Skipping.")
+
+ with output_path.open("rb") as f:
+ run_data = json.load(f)
+ return run_data
+
+ if server is None:
+ assert dry_run
+ print("[END BENCHMARK]")
+ return None
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ server.run_subcommand(benchmark_cmd)
+ server.after_bench()
+
+ with output_path.open("rb") as f:
+ run_data = json.load(f)
+
+ run_data["run_number"] = run_number
+ run_data.update(serve_overrides)
+
+ with output_path.open("w") as f:
+ json.dump(run_data, f, indent=4)
+
+ print("[END BENCHMARK]")
+
+ return run_data
+
+
+def _get_comb_base_path(
+ output_dir: Path,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+):
+ return output_dir / "-".join(
+ (
+ "SERVE",
+ *(f"{k}={v}" for k, v in serve_comb.items()),
+ "BENCH",
+ *(f"{k}={v}" for k, v in bench_comb.items()),
+ )
+ ).replace("/", "_").replace("..", "__") # Sanitize
+
+
+def _get_comb_run_path(base_path: Path, run_number: int | None):
+ if run_number is None:
+ return base_path / "summary.json"
+
+ return base_path / f"run={run_number}.json"
+
+
+def _comb_needs_server(
+ serve_comb: dict[str, object],
+ bench_combs: list[dict[str, object]],
+ output_dir: Path,
+):
+ for bench_comb in bench_combs:
+ base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
+ if not _get_comb_run_path(base_path, run_number=None).exists():
+ return True
+
+ return False
+
+
+def _run_comb(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+ base_path: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ comb_data = list[dict[str, object]]()
+
+ for run_number in range(num_runs):
+ run_data = _run_benchmark(
+ server,
+ bench_cmd,
+ serve_overrides=serve_comb,
+ bench_overrides=bench_comb,
+ run_number=run_number,
+ output_path=_get_comb_run_path(base_path, run_number),
+ dry_run=dry_run,
+ )
+
+ if run_data is not None:
+ comb_data.append(run_data)
+
+ if dry_run:
+ return None
+
+ with _get_comb_run_path(base_path, run_number=None).open("w") as f:
+ json.dump(comb_data, f, indent=4)
+
+ return comb_data
+
+
+def run_combs(
+ serve_cmd: list[str],
+ bench_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ serve_params: list[dict[str, object]],
+ bench_params: list[dict[str, object]],
+ output_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ all_data = list[dict[str, object]]()
+ for serve_comb in serve_params:
+ with (
+ _run_server(
+ serve_cmd,
+ after_bench_cmd,
+ show_stdout=show_stdout,
+ serve_overrides=serve_comb,
+ dry_run=dry_run,
+ )
+ if _comb_needs_server(serve_comb, bench_params, output_dir)
+ else contextlib.nullcontext()
+ ) as server:
+ for bench_comb in bench_params:
+ base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
+
+ comb_data = _run_comb(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ base_path=base_path,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ if comb_data is not None:
+ all_data.extend(comb_data)
+
+ if dry_run:
+ return None
+
+ combined_df = pd.DataFrame.from_records(all_data)
+ combined_df.to_csv(output_dir / "summary.csv")
+
+ return combined_df
+
+
+def _get_sla_base_path(
+ output_dir: Path,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+):
+ return output_dir / "-".join(
+ (
+ "SERVE",
+ *(f"{k}={v}" for k, v in serve_comb.items()),
+ "BENCH",
+ *(f"{k}={v}" for k, v in bench_comb.items()),
+ )
+ ).replace("/", "_").replace("..", "__") # Sanitize
+
+
+def _get_sla_iter_path(
+ base_path: Path,
+ sla_comb: dict[str, SLACriterionBase],
+ sla_variable: str,
+ sla_value: int | None,
+):
+ if sla_value is None:
+ prefix = "-".join(v.format_cond(k) for k, v in sla_comb.items())
+ return base_path / f"SLA-{prefix}.json"
+
+ return base_path / f"{sla_variable}={sla_value}"
+
+
+def _get_sla_run_path(iter_path: Path, run_number: int | None):
+ if run_number is None:
+ return iter_path / "summary.json"
+
+ return iter_path / f"run={run_number}.json"
+
+
+def _sla_needs_server(
+ serve_comb: dict[str, object],
+ bench_combs: list[dict[str, object]],
+ sla_combs: list[dict[str, SLACriterionBase]],
+ sla_variable: str,
+ output_dir: Path,
+):
+ for bench_comb in bench_combs:
+ base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
+ for sla_comb in sla_combs:
+ if not _get_sla_iter_path(
+ base_path,
+ sla_comb,
+ sla_variable,
+ sla_value=None,
+ ).exists():
+ return True
+
+ return False
+
+
+def _run_sla(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+ iter_path: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ iter_data = list[dict[str, object]]()
+
+ for run_number in range(num_runs):
+ run_data = _run_benchmark(
+ server,
+ bench_cmd,
+ serve_overrides=serve_comb,
+ bench_overrides=bench_comb,
+ run_number=run_number,
+ output_path=_get_sla_run_path(iter_path, run_number),
+ dry_run=dry_run,
+ )
+
+ if run_data is not None:
+ iter_data.append(run_data)
+
+ if dry_run:
+ return None
+
+ with _get_sla_run_path(iter_path, run_number=None).open("w") as f:
+ json.dump(iter_data, f, indent=4)
+
+ return iter_data
+
+
+SLAVariable = Literal["request_rate", "max_concurrency"]
+
+
+def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
+ request_throughput = float(run_data["request_throughput"]) # type: ignore
+ if sla_variable == "request_rate":
+ return request_throughput
+ if sla_variable == "max_concurrency":
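+        # Little's law: concurrency = throughput (req/s) * mean E2E latency (s)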
+ mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
+ return request_throughput * mean_latency_ms / 1000
+
+ assert_never(sla_variable)
+
+
+def _estimate_sla_bounds(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+ sla_comb: dict[str, SLACriterionBase],
+ base_path: Path,
+ num_runs: int,
+ dry_run: bool,
+ sla_variable: SLAVariable,
+ init_value: int,
+ max_value: int,
+):
+ sla_data = list[dict[str, object]]()
+
+    max_passing: int = 0
+    # If the SLA is satisfied all the way up to `max_value`, the search window ends there
+    min_failing: int = max_value
+
+ val: int = init_value
+ assert val > 0
+
+ while True:
+ print(f"Testing {sla_variable}: {val} req/s")
+
+ iter_data = _run_sla(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb={**bench_comb, sla_variable: val},
+ iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ assert iter_data is not None
+ sla_data.extend(iter_data)
+
+ iter_data_mean = {
+ k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
+ for k in sla_comb
+ }
+
+ sla_results = [
+ criterion.print_and_validate(iter_data_mean, k)
+ for k, criterion in sla_comb.items()
+ ]
+
+ if all(sla_results):
+ print("SLA criteria are met.")
+ max_passing = val
+ val *= 2
+ else:
+ print("SLA criteria are not met.")
+ min_failing = val
+ break
+
+ if val >= max_value:
+ break
+
+ return sla_data, (max_passing, min_failing)
+
+
+def _find_sla_value(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+ sla_comb: dict[str, SLACriterionBase],
+ base_path: Path,
+ num_runs: int,
+ dry_run: bool,
+ sla_variable: SLAVariable,
+ min_value: int,
+ max_value: int,
+):
+ sla_data = list[dict[str, object]]()
+
+ left: int = min_value
+ right: int = max_value
+
+ while True:
+ val = (left + right) // 2
+ print(f"Testing {sla_variable}: {val} req/s")
+
+ iter_data = _run_sla(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb={**bench_comb, sla_variable: val},
+ iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ assert iter_data is not None
+ sla_data.extend(iter_data)
+
+ iter_data_mean = {
+ k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
+ for k in sla_comb
+ }
+
+ sla_results = [
+ criterion.print_and_validate(iter_data_mean, k)
+ for k, criterion in sla_comb.items()
+ ]
+
+ if all(sla_results):
+ print("SLA criteria are met.")
+ left = val
+ else:
+ print("SLA criteria are not met.")
+ right = val
+
+ if right - left <= 1:
+ break
+
+ return sla_data, left
+
+
+def _search_sla(
+ server: ServerWrapper | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: dict[str, object],
+ bench_comb: dict[str, object],
+ sla_comb: dict[str, SLACriterionBase],
+ sla_variable: SLAVariable,
+ sla_inf_value: int = 65536, # The value that represents infinite QPS
+ base_path: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ print("[SLA START]")
+ print(f"SLA criteria: {', '.join(v.format_cond(k) for k, v in sla_comb.items())}")
+
+ sla_data_0 = _run_sla(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb={**bench_comb, sla_variable: sla_inf_value},
+ iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+ if sla_data_0 is None:
+ assert dry_run
+ print("Omitting SLA search.")
+ print("[SLA END]")
+ return None
+
+ sla_init_value = math.ceil(
+ sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
+ / len(sla_data_0)
+ )
+ print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
+
+ sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ sla_comb=sla_comb,
+ base_path=base_path,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ sla_variable=sla_variable,
+ init_value=sla_init_value,
+ max_value=sla_inf_value,
+ )
+ print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
+
+ sla_data_2, sla_value = _find_sla_value(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ sla_comb=sla_comb,
+ base_path=base_path,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ sla_variable=sla_variable,
+ min_value=sla_min,
+ max_value=sla_max,
+ )
+
+ sla_data = sla_data_0 + sla_data_1 + sla_data_2
+ print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
+
+ with _get_sla_iter_path(
+ base_path,
+ sla_comb,
+ sla_variable,
+ sla_value=None,
+ ).open("w") as f:
+ json.dump(sla_data, f, indent=4)
+
+ print("[SLA END]")
+
+ return sla_data
+
+
+def _plot_throughput_latency_curve(
+ all_data: list[dict[str, object]],
+ serve_combs: list[dict[str, object]],
+ bench_comb: dict[str, object],
+ output_dir: Path,
+):
+ fig_path = output_dir / "-".join(
+ (
+ "BENCH",
+ *(f"{k}={v}" for k, v in bench_comb.items()),
+ )
+ ).replace("/", "_").replace("..", "__") # Sanitize
+
+ df = pd.DataFrame.from_records(
+ [item for item in all_data if all(item[k] == bench_comb[k] for k in bench_comb)]
+ )
+
+ # Group together points with similar throughput
+ df["request_throughput"] = df["request_throughput"].round()
+
+ # Preserve the key order using dictionary
+ all_comb_keys = {k: None for comb in serve_combs for k in comb}
+ for k in all_comb_keys:
+ df[k] = df[k].astype(str)
+
+ keys_per_comb = [comb.keys() for comb in serve_combs]
+ if (
+ all(ks == keys_per_comb[0] for ks in keys_per_comb)
+ and len(keys_per_comb[0]) <= 3
+ ):
+ hue, style, size, *_ = (*keys_per_comb[0], None, None)
+ ax = sns.lineplot(
+ df,
+ x="request_throughput",
+ y="p99_e2el_ms",
+ hue=hue,
+ style=style,
+ size=size,
+ markers=True,
+ )
+ else:
+ df["category"] = df[list(all_comb_keys)].agg("-".join, axis=1)
+ ax = sns.lineplot(
+ df,
+ x="request_throughput",
+ y="p99_e2el_ms",
+ hue="category",
+ markers=True,
+ )
+
+ sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
+
+ fig = ax.get_figure()
+ assert fig is not None
+
+    fig.tight_layout()
+    fig.savefig(fig_path)
+    # Close the figure so curves from different bench combinations don't accumulate
+    plt.close(fig)
+
+
+def _plot_throughput_latency_curves(
+ all_data: list[dict[str, object]],
+ serve_combs: list[dict[str, object]],
+ bench_combs: list[dict[str, object]],
+ output_dir: Path,
+):
+ for bench_comb in bench_combs:
+ _plot_throughput_latency_curve(all_data, serve_combs, bench_comb, output_dir)
+
+
+def run_slas(
+ serve_cmd: list[str],
+ bench_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ serve_params: list[dict[str, object]],
+ bench_params: list[dict[str, object]],
+ sla_params: list[dict[str, SLACriterionBase]],
+ sla_variable: SLAVariable,
+ output_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ if any(
+ k in bench_comb
+ for bench_comb in bench_params
+ for k in _iter_param_key_candidates(sla_variable)
+ ):
+ raise ValueError(
+ f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
+ "since it is supposed to be determined automatically."
+ )
+
+ all_data = list[dict[str, object]]()
+ for serve_comb in serve_params:
+ with (
+ _run_server(
+ serve_cmd,
+ after_bench_cmd,
+ show_stdout=show_stdout,
+ serve_overrides=serve_comb,
+ dry_run=dry_run,
+ )
+ if _sla_needs_server(
+ serve_comb,
+ bench_params,
+ sla_params,
+ sla_variable,
+ output_dir,
+ )
+ else contextlib.nullcontext()
+ ) as server:
+ for bench_comb in bench_params:
+ for sla_comb in sla_params:
+ base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
+
+ comb_data = _search_sla(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ sla_comb=sla_comb,
+ sla_variable=sla_variable,
+ base_path=base_path,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ if comb_data is not None:
+ all_data.extend(comb_data)
+
+ if dry_run:
+ return None
+
+ combined_df = pd.DataFrame.from_records(all_data)
+ combined_df.to_csv(output_dir / "summary.csv")
+
+ _plot_throughput_latency_curves(all_data, serve_params, bench_params, output_dir)
+
+ return combined_df
+
+
+def _run_main(
+ serve_cmd: list[str],
+ bench_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ serve_params: list[dict[str, object]],
+ bench_params: list[dict[str, object]],
+ sla_params: list[dict[str, SLACriterionBase]],
+ sla_variable: SLAVariable,
+ output_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ if sla_params:
+ return run_slas(
+ serve_cmd=serve_cmd,
+ bench_cmd=bench_cmd,
+ after_bench_cmd=after_bench_cmd,
+ show_stdout=show_stdout,
+ serve_params=serve_params,
+ bench_params=bench_params,
+ sla_params=sla_params,
+ sla_variable=sla_variable,
+ output_dir=output_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ return run_combs(
+ serve_cmd=serve_cmd,
+ bench_cmd=bench_cmd,
+ after_bench_cmd=after_bench_cmd,
+ show_stdout=show_stdout,
+ serve_params=serve_params,
+ bench_params=bench_params,
+ output_dir=output_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+
+def run_main(
+ serve_cmd: list[str],
+ bench_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ serve_params: list[dict[str, object]],
+ bench_params: list[dict[str, object]],
+ sla_params: list[dict[str, SLACriterionBase]],
+ sla_variable: SLAVariable,
+ output_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+ resume: str | None,
+):
+ timestamp = resume or datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_dir = output_dir / timestamp
+
+ if resume and not output_dir.exists():
+ raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
+
+ try:
+ return _run_main(
+ serve_cmd=serve_cmd,
+ bench_cmd=bench_cmd,
+ after_bench_cmd=after_bench_cmd,
+ show_stdout=show_stdout,
+ serve_params=serve_params,
+ bench_params=bench_params,
+ sla_params=sla_params,
+ sla_variable=sla_variable,
+ output_dir=output_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+ except BaseException as exc:
+ raise RuntimeError(
+ f"The script was terminated early. Use `--resume {timestamp}` "
+ f"to continue the script from its last checkpoint."
+ ) from exc
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Run vLLM server benchmark on a parameter grid of settings."
+ )
+ parser.add_argument(
+ "--serve-cmd",
+ type=str,
+ required=True,
+ help="The command used to run the server: `vllm serve ...`",
+ )
+ parser.add_argument(
+ "--bench-cmd",
+ type=str,
+ required=True,
+ help="The command used to run the benchmark: `vllm bench serve ...`",
+ )
+ parser.add_argument(
+ "--after-bench-cmd",
+ type=str,
+ default=None,
+ help="After a benchmark run is complete, invoke this command instead of the "
+ "default `ServerWrapper.clear_cache()`.",
+ )
+ parser.add_argument(
+ "--show-stdout",
+ action="store_true",
+ help="If set, logs the standard output of subcommands. "
+ "Useful for debugging but can be quite spammy.",
+ )
+ parser.add_argument(
+ "--serve-params",
+ type=str,
+ default=None,
+ help="Path to JSON file containing a list of parameter combinations "
+ "for the `vllm serve` command. "
+ "If both `serve_params` and `bench_params` are given, "
+ "this script will iterate over their Cartesian product.",
+ )
+ parser.add_argument(
+ "--bench-params",
+ type=str,
+ default=None,
+ help="Path to JSON file containing a list of parameter combinations "
+ "for the `vllm bench serve` command. "
+ "If both `serve_params` and `bench_params` are given, "
+ "this script will iterate over their Cartesian product.",
+ )
+ parser.add_argument(
+ "--sla-params",
+ type=str,
+ default=None,
+ help="Path to JSON file containing a list of SLA constraints to satisfy. "
+        'Each constraint is expressed in `{"<metric_key>": "<op><target>"}` format, '
+        'e.g.: `{"p99_e2el_ms": "<=500"}` means that '
+        "the E2E latency should be within 500ms for 99%% of requests. "
+ "Setting this option runs this script in SLA mode, which searches for the "
+ "maximum `sla_variable` that satisfies the constraints for each combination "
+ "of `serve_params`, `bench_params`, and `sla_params`.",
+ )
+ parser.add_argument(
+ "--sla-variable",
+ type=str,
+ choices=get_args(SLAVariable),
+ default="request_rate",
+ help="Whether to tune request rate or maximum concurrency to satisfy "
+ "the SLA constraints.",
+ )
+ parser.add_argument(
+ "-o",
+ "--output-dir",
+ type=str,
+ default="results",
+ help="The directory to which results are written.",
+ )
+ parser.add_argument(
+ "--num-runs",
+ type=int,
+ default=3,
+ help="Number of runs per parameter combination.",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ help="If set, prints the commands to run then exits without running them.",
+ )
+ parser.add_argument(
+ "--resume",
+ type=str,
+ default=None,
+ help="Set this to the name of a directory under `output_dir` (which is a "
+ "timestamp) to resume a previous execution of this script, i.e., only run "
+ "parameter combinations for which there are still no output files.",
+ )
+
+ args = parser.parse_args()
+
+ serve_cmd = shlex.split(args.serve_cmd)
+ bench_cmd = shlex.split(args.bench_cmd)
+ after_bench_cmd = (
+ [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
+ )
+
+ serve_params: list[dict[str, object]]
+ if args.serve_params:
+ with open(args.serve_params, "rb") as f:
+ serve_params = _parse_params(json.load(f))
+ else:
+ # i.e.: run serve_cmd without any modification
+ serve_params = [{}]
+
+ bench_params: list[dict[str, object]]
+ if args.bench_params:
+ with open(args.bench_params, "rb") as f:
+ bench_params = _parse_params(json.load(f))
+ else:
+ # i.e.: run bench_cmd without any modification
+ bench_params = [{}]
+
+ sla_params: list[dict[str, SLACriterionBase]]
+ if args.sla_params:
+ with open(args.sla_params, "rb") as f:
+ sla_params = _parse_sla(json.load(f))
+ else:
+ sla_params = []
+
+ num_runs = args.num_runs
+ if num_runs < 1:
+ raise ValueError("`num_runs` should be at least 1.")
+
+ run_main(
+ serve_cmd=serve_cmd,
+ bench_cmd=bench_cmd,
+ after_bench_cmd=after_bench_cmd,
+ show_stdout=args.show_stdout,
+ serve_params=serve_params,
+ bench_params=bench_params,
+ sla_params=sla_params,
+ sla_variable=args.sla_variable,
+ output_dir=Path(args.output_dir),
+ num_runs=num_runs,
+ dry_run=args.dry_run,
+ resume=args.resume,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 0ac0355956908..40508eebdcb69 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -994,6 +994,16 @@ if envs.VLLM_SERVER_DEV_MODE:
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200)
+ @router.post("/reset_mm_cache")
+ async def reset_mm_cache(raw_request: Request):
+ """
+ Reset the multi-modal cache. Note that we currently do not check if the
+ multi-modal cache is successfully reset in the API server.
+ """
+ logger.info("Resetting multi-modal cache...")
+ await engine_client(raw_request).reset_mm_cache()
+ return Response(status_code=200)
+
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params