[CI/Build][Doc] Clean up more docs that point to old bench scripts (#21667)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Ye (Charlotte) Qi 2025-07-26 21:02:12 -07:00 committed by GitHub
parent 971948b846
commit 01a395e9e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 66 additions and 63 deletions

View File

@ -74,7 +74,7 @@ Here is an example of one test inside `latency-tests.json`:
In this example: In this example:
- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. - The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` - The `parameters` attribute control the command line arguments to be used for `vllm bench latency`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `vllm bench latency`. For example, the corresponding command line arguments for `vllm bench latency` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
@ -82,13 +82,13 @@ WARNING: The benchmarking script will save json results by itself, so please do
### Throughput test ### Throughput test
The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `vllm bench throughput`.
The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot.
### Serving test ### Serving test
We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: We test the throughput by using `vllm bench serve` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:
```json ```json
[ [
@ -118,8 +118,8 @@ Inside this example:
- The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. - The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`.
- The `server-parameters` includes the command line arguments for vLLM server. - The `server-parameters` includes the command line arguments for vLLM server.
- The `client-parameters` includes the command line arguments for `benchmark_serving.py`. - The `client-parameters` includes the command line arguments for `vllm bench serve`.
- The `qps_list` controls the list of qps for test. It will be used to configure the `--request-rate` parameter in `benchmark_serving.py` - The `qps_list` controls the list of qps for test. It will be used to configure the `--request-rate` parameter in `vllm bench serve`
The number of this test is less stable compared to the delay and latency benchmarks (due to randomized sharegpt dataset sampling inside `benchmark_serving.py`), but a large change on this number (e.g. 5% change) still vary the output greatly. The number of this test is less stable compared to the delay and latency benchmarks (due to randomized sharegpt dataset sampling inside `benchmark_serving.py`), but a large change on this number (e.g. 5% change) still vary the output greatly.

View File

@ -100,7 +100,7 @@ if __name__ == "__main__":
raw_result = json.loads(f.read()) raw_result = json.loads(f.read())
if "serving" in str(test_file): if "serving" in str(test_file):
# this result is generated via `benchmark_serving.py` # this result is generated via `vllm bench serve` command
# attach the benchmarking command to raw_result # attach the benchmarking command to raw_result
try: try:
@ -120,7 +120,7 @@ if __name__ == "__main__":
continue continue
elif "latency" in f.name: elif "latency" in f.name:
# this result is generated via `benchmark_latency.py` # this result is generated via `vllm bench latency` command
# attach the benchmarking command to raw_result # attach the benchmarking command to raw_result
try: try:
@ -148,7 +148,7 @@ if __name__ == "__main__":
continue continue
elif "throughput" in f.name: elif "throughput" in f.name:
# this result is generated via `benchmark_throughput.py` # this result is generated via `vllm bench throughput` command
# attach the benchmarking command to raw_result # attach the benchmarking command to raw_result
try: try:

View File

@ -127,7 +127,7 @@ ensure_installed() {
} }
run_serving_tests() { run_serving_tests() {
# run serving tests using `benchmark_serving.py` # run serving tests using `vllm bench serve` command
# $1: a json file specifying serving test cases # $1: a json file specifying serving test cases
local serving_test_file local serving_test_file

View File

@ -165,7 +165,7 @@ upload_to_buildkite() {
} }
run_latency_tests() { run_latency_tests() {
# run latency tests using `benchmark_latency.py` # run latency tests using `vllm bench latency` command
# $1: a json file specifying latency test cases # $1: a json file specifying latency test cases
local latency_test_file local latency_test_file
@ -232,7 +232,7 @@ run_latency_tests() {
} }
run_throughput_tests() { run_throughput_tests() {
# run throughput tests using `benchmark_throughput.py` # run throughput tests using `vllm bench throughput`
# $1: a json file specifying throughput test cases # $1: a json file specifying throughput test cases
local throughput_test_file local throughput_test_file
@ -298,7 +298,7 @@ run_throughput_tests() {
} }
run_serving_tests() { run_serving_tests() {
# run serving tests using `benchmark_serving.py` # run serving tests using `vllm bench serve` command
# $1: a json file specifying serving test cases # $1: a json file specifying serving test cases
local serving_test_file local serving_test_file
@ -448,7 +448,7 @@ main() {
(which jq) || (apt-get update && apt-get -y install jq) (which jq) || (apt-get update && apt-get -y install jq)
(which lsof) || (apt-get update && apt-get install -y lsof) (which lsof) || (apt-get update && apt-get install -y lsof)
# get the current IP address, required by benchmark_serving.py # get the current IP address, required by `vllm bench serve` command
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# turn of the reporting of the status of each request, to clean up the terminal output # turn of the reporting of the status of each request, to clean up the terminal output
export VLLM_LOGGING_LEVEL="WARNING" export VLLM_LOGGING_LEVEL="WARNING"

View File

@ -105,7 +105,7 @@ After the script finishes, you will find the results in a new, timestamped direc
- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run: - **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run:
- `vllm_log_...txt`: The log output from the vLLM server for each parameter combination. - `vllm_log_...txt`: The log output from the vLLM server for each parameter combination.
- `bm_log_...txt`: The log output from the `benchmark_serving.py` script for each benchmark run. - `bm_log_...txt`: The log output from the `vllm bench serve` command for each benchmark run.
- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found. - **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found.

View File

@ -3,7 +3,7 @@
# benchmark the overhead of disaggregated prefill. # benchmark the overhead of disaggregated prefill.
# methodology: # methodology:
# - send all request to prefill vLLM instance. It will buffer KV cache. # - send all request to prefill vLLM instance. It will buffer KV cache.
# - then send all request to decode instance. # - then send all request to decode instance.
# - The TTFT of decode instance is the overhead. # - The TTFT of decode instance is the overhead.
set -ex set -ex
@ -63,7 +63,7 @@ benchmark() {
--gpu-memory-utilization 0.6 \ --gpu-memory-utilization 0.6 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \ CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \ -m vllm.entrypoints.openai.api_server \
@ -78,38 +78,38 @@ benchmark() {
wait_for_server 8200 wait_for_server 8200
# let the prefill instance finish prefill # let the prefill instance finish prefill
python3 ../benchmark_serving.py \ vllm bench serve \
--backend vllm \ --backend vllm \
--model $model \ --model $model \
--dataset-name $dataset_name \ --dataset-name $dataset_name \
--dataset-path $dataset_path \ --dataset-path $dataset_path \
--sonnet-input-len $input_len \ --sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \ --sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \ --sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \ --num-prompts $num_prompts \
--port 8100 \ --port 8100 \
--save-result \ --save-result \
--result-dir $results_folder \ --result-dir $results_folder \
--result-filename disagg_prefill_tp1.json \ --result-filename disagg_prefill_tp1.json \
--request-rate "inf" --request-rate "inf"
# send the request to decode. # send the request to decode.
# The TTFT of this command will be the overhead of disagg prefill impl. # The TTFT of this command will be the overhead of disagg prefill impl.
python3 ../benchmark_serving.py \ vllm bench serve \
--backend vllm \ --backend vllm \
--model $model \ --model $model \
--dataset-name $dataset_name \ --dataset-name $dataset_name \
--dataset-path $dataset_path \ --dataset-path $dataset_path \
--sonnet-input-len $input_len \ --sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \ --sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \ --sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \ --num-prompts $num_prompts \
--port 8200 \ --port 8200 \
--save-result \ --save-result \
--result-dir $results_folder \ --result-dir $results_folder \
--result-filename disagg_prefill_tp1_overhead.json \ --result-filename disagg_prefill_tp1_overhead.json \
--request-rate "$qps" --request-rate "$qps"
kill_gpu_processes kill_gpu_processes
} }

View File

@ -60,7 +60,7 @@ launch_chunked_prefill() {
launch_disagg_prefill() { launch_disagg_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct" model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill # disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \ CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \ -m vllm.entrypoints.openai.api_server \
@ -99,20 +99,20 @@ benchmark() {
output_len=$2 output_len=$2
tag=$3 tag=$3
python3 ../benchmark_serving.py \ vllm bench serve \
--backend vllm \ --backend vllm \
--model $model \ --model $model \
--dataset-name $dataset_name \ --dataset-name $dataset_name \
--dataset-path $dataset_path \ --dataset-path $dataset_path \
--sonnet-input-len $input_len \ --sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \ --sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \ --sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \ --num-prompts $num_prompts \
--port 8000 \ --port 8000 \
--save-result \ --save-result \
--result-dir $results_folder \ --result-dir $results_folder \
--result-filename "$tag"-qps-"$qps".json \ --result-filename "$tag"-qps-"$qps".json \
--request-rate "$qps" --request-rate "$qps"
sleep 2 sleep 2
} }

View File

@ -9,10 +9,13 @@ We support tracing vLLM workers using the `torch.profiler` module. You can enabl
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set. The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
When using `benchmarks/benchmark_serving.py`, you can enable profiling by passing the `--profile` flag. When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
Traces can be visualized using <https://ui.perfetto.dev/>. Traces can be visualized using <https://ui.perfetto.dev/>.
!!! tip
You can directly call bench module without installing vllm using `python -m vllm.entrypoints.cli.main bench`.
!!! tip !!! tip
Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly.
@ -35,7 +38,7 @@ VLLM_TORCH_PROFILER_DIR=./vllm_profile \
--model meta-llama/Meta-Llama-3-70B --model meta-llama/Meta-Llama-3-70B
``` ```
benchmark_serving.py: vllm bench command:
```bash ```bash
vllm bench serve \ vllm bench serve \
@ -69,7 +72,7 @@ apt install nsight-systems-cli
For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference. For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference.
The following is an example using the `benchmarks/benchmark_latency.py` script: The following is an example using the `vllm bench latency` script:
```bash ```bash
nsys profile -o report.nsys-rep \ nsys profile -o report.nsys-rep \

View File

@ -28,7 +28,7 @@ Submit some sample requests to the server:
```bash ```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
python3 ../../../benchmarks/benchmark_serving.py \ vllm bench serve \
--model mistralai/Mistral-7B-v0.1 \ --model mistralai/Mistral-7B-v0.1 \
--tokenizer mistralai/Mistral-7B-v0.1 \ --tokenizer mistralai/Mistral-7B-v0.1 \
--endpoint /v1/completions \ --endpoint /v1/completions \