[benchmark] add peak throughput metrics and plot (#23867)

Signed-off-by: simon-mo <simon.mo@hey.com>
Simon Mo 2025-09-17 22:30:02 -07:00 committed by GitHub
parent b7433ca1a4
commit a904ea78ea
2 changed files with 134 additions and 69 deletions
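Note on the approach: the peak numbers added by this change anchor every successful request to a shared clock (the new start_time field), reconstruct each output token's emission time from ttft and the inter-token latencies, and bucket tokens and active requests into one-second bins; the maxima of those bins are reported as peak output throughput and peak concurrency. A minimal, self-contained sketch of that bookkeeping (illustrative names only, not the committed code):

import math

def peak_metrics(requests):
    # requests: list of (start_time, ttft, itl, latency) tuples, all in
    # seconds on the same clock (e.g. time.perf_counter()); assumes at
    # least one successful request.
    t0 = min(r[0] for r in requests)
    t_end = max(r[0] + r[3] for r in requests)
    n_buckets = int(math.ceil(t_end - t0)) + 1
    tokens = [0] * n_buckets            # output tokens emitted per second
    active = [0] * n_buckets            # requests in flight per second
    for start, ttft, itl, latency in requests:
        t = start + ttft                # first output token
        for gap in [0.0] + list(itl):   # remaining tokens follow the ITLs
            t += gap
            bucket = int(t - t0)
            if 0 <= bucket < n_buckets:
                tokens[bucket] += 1
        for b in range(int(start - t0), int(start + latency - t0) + 1):
            active[b] += 1
    return max(tokens), max(active)     # peak tok/s, peak concurrency

Applied to the per-request outputs gathered by the benchmark, the same logic drives the new "Peak output token throughput" and "Peak concurrent requests" lines and, when termplotlib and gnuplot are installed, the terminal plots.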

View File

@ -89,6 +89,7 @@ class RequestFuncOutput:
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
start_time: float = 0.0
async def async_request_openai_completions(
@ -140,6 +141,7 @@ async def async_request_openai_completions(
generated_text = "" generated_text = ""
st = time.perf_counter() st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
generated_text = "" generated_text = ""
ttft = 0.0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
@ -396,6 +399,7 @@ async def async_request_openai_audio(
generated_text = "" generated_text = ""
ttft = 0.0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url,
@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
output = RequestFuncOutput()
st = time.perf_counter()
output.start_time = st
try:
async with session.post(
url=api_url,
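The start_time recorded in each request function above gives the metrics code a common clock anchor per request; together with ttft and the inter-token latencies it is later used to place every generated token on an absolute timeline, roughly (a sketch mirroring the logic added in the second file, not a quote of it):

token_times = [output.start_time + output.ttft]   # first output token
for gap in output.itl:                             # remaining tokens
    token_times.append(token_times[-1] + gap)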

View File

@ -18,9 +18,11 @@ On the client side, run:
import argparse
import asyncio
import gc
import importlib.util
import json
import os
import random
import shutil
import time
import warnings
from collections.abc import AsyncGenerator, Iterable
@ -46,6 +48,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
and (shutil.which("gnuplot") is not None))
class TaskType(Enum):
GENERATION = "generation"
@ -80,18 +85,23 @@ class BenchmarkMetrics:
median_e2el_ms: float
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
# Max output tokens per second and concurrent requests at that peak
max_output_tokens_per_s: float
max_concurrent_requests: int
@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput: float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: float
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
@ -150,8 +160,8 @@ async def get_request(
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
# Convert to list to get length for ramp-up calculations
if isinstance(input_requests,
Iterable) and not isinstance(input_requests, list):
input_requests = list(input_requests)
total_requests = len(input_requests)
@ -161,12 +171,9 @@ async def get_request(
request_rates = []
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(
ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps,
request_index, total_requests, request_rate)
request_rates.append(current_request_rate)
if current_request_rate == float("inf"):
delay_ts.append(0)
@ -206,10 +213,8 @@ async def get_request(
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput], dur_s: float,
selected_percentiles: list[float]) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
Args:
@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
)
return metrics
@ -336,6 +339,67 @@ def calculate_metrics(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.", "on the benchmark arguments.",
stacklevel=2) stacklevel=2)
# Calculate max output tokens per second metric
max_output_tokens_per_s = 0.0
max_concurrent_requests = 0
# Find the time range across all successful requests
successful_outputs = [output for output in outputs if output.success]
if successful_outputs:
min_start_time = min(output.start_time
for output in successful_outputs)
max_end_time = max(output.start_time + output.latency
for output in successful_outputs)
# Create second buckets (ceiling to ensure we capture all time)
duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
tokens_per_second = np.zeros(duration_seconds)
concurrent_requests_per_second = np.zeros(duration_seconds)
for i, output in enumerate(successful_outputs):
# Calculate token generation timestamp using
# start_time, ttft, and itl
token_times = [output.start_time + output.ttft]
current_time = token_times[0]
for itl_value in output.itl:
current_time += itl_value
token_times.append(current_time)
# Add tokens to second buckets
for token_time in token_times:
second_bucket = int(token_time - min_start_time)
if 0 <= second_bucket < duration_seconds:
tokens_per_second[second_bucket] += 1
# Track concurrent requests for each second this request was active
request_start_second = int(output.start_time - min_start_time)
request_end_second = int((output.start_time + output.latency) -
min_start_time)
for second in range(request_start_second, request_end_second + 1):
concurrent_requests_per_second[second] += 1
# Find the maximum tokens per second and corresponding
# concurrent requests
if len(tokens_per_second) > 0:
max_output_tokens_per_s = float(np.max(tokens_per_second))
max_concurrent_requests = int(
np.max(concurrent_requests_per_second))
if TERM_PLOTLIB_AVAILABLE:
import termplotlib as tpl
fig = tpl.figure()
fig.plot(np.arange(len(tokens_per_second)),
tokens_per_second,
title="Output tokens per second")
fig.plot(np.arange(len(concurrent_requests_per_second)),
concurrent_requests_per_second,
title="Concurrent requests per second")
fig.show()
else:
print("tip: install termplotlib and gnuplot to plot the metrics")
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@ -365,6 +429,8 @@ def calculate_metrics(
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
max_output_tokens_per_s=max_output_tokens_per_s,
max_concurrent_requests=max_concurrent_requests,
)
return metrics, actual_output_lens
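For intuition, a hypothetical walk-through of the bucket arithmetic above: a request with start_time 10.0 s (equal to min_start_time), ttft 0.5 s, itl [0.1, 0.1, 0.1] and latency 0.8 s emits its four tokens at 0.5, 0.6, 0.7 and 0.8 s after the anchor, so all of them land in bucket 0, and the request is counted as active only in that bucket; the reported peaks are the maxima of these per-second counters across all successful requests.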
@ -396,11 +462,8 @@ async def benchmark(
ramp_up_end_rps: Optional[int] = None,
ready_check_timeout_sec: int = 600,
):
task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else
TaskType.GENERATION)
if endpoint_type in ASYNC_REQUEST_FUNCS:
if task_type == TaskType.EMBEDDING:
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
@ -435,14 +498,10 @@ async def benchmark(
input_requests[0].multi_modal_data,
)
assert (test_mm_content is None or isinstance(test_mm_content, dict)
or (isinstance(test_mm_content, list)
and all(isinstance(item, dict) for item in test_mm_content))
), "multi_modal_data must be a dict or list[dict]"
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
@ -488,13 +547,13 @@ async def benchmark(
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body)
profile_output = await request_func(request_func_input=profile_input,
session=session)
if profile_output.success:
print("Profiler started")
distribution = ("Poisson process"
if burstiness == 1.0 else "Gamma distribution")
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@ -562,18 +621,20 @@ async def benchmark(
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,
)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
@ -615,19 +676,21 @@ async def benchmark(
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
if isinstance(metrics, BenchmarkMetrics):
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format(
"Peak output token throughput (tok/s):",
metrics.max_output_tokens_per_s))
print("{:<40} {:<10.2f}".format("Peak concurrent requests:",
metrics.max_concurrent_requests))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
@ -648,6 +711,8 @@ async def benchmark(
"itls": [output.itl for output in outputs], "itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs], "generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
"max_concurrent_requests": metrics.max_concurrent_requests,
}
else:
result = {
@ -697,8 +762,8 @@ async def benchmark(
if task_type == TaskType.GENERATION:
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@ -714,8 +779,8 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
)
profile_output = await request_func(request_func_input=profile_input,
session=session)
if profile_output.success:
print("Profiler stopped")
@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Specify the prefix of request id.", help="Specify the prefix of request id.",
) )
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument( sampling_group.add_argument(
"--top-p", "--top-p",
@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The ramp-up strategy. This would be used to " help="The ramp-up strategy. This would be used to "
"ramp up the request rate from initial RPS to final " "ramp up the request rate from initial RPS to final "
"RPS rate (specified by --ramp-up-start-rps and " "RPS rate (specified by --ramp-up-start-rps and "
"--ramp-up-end-rps.) over the duration of the benchmark." "--ramp-up-end-rps.) over the duration of the benchmark.")
)
parser.add_argument( parser.add_argument(
"--ramp-up-start-rps", "--ramp-up-start-rps",
type=int, type=int,
@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
raise ValueError(
"When using ramp-up, do not specify --request-rate. "
"The request rate will be controlled by ramp-up parameters. "
"Please remove the --request-rate argument.")
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
raise ValueError(
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"--ramp-up-end-rps must be specified")
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
raise ValueError("Ramp-up start and end RPS must be non-negative")
if args.ramp_up_start_rps > args.ramp_up_end_rps:
@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
headers[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid header format. Please use KEY=VALUE format.")
tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format.")
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate