[benchmark] add peak throughput metrics and plot (#23867)

Signed-off-by: simon-mo <simon.mo@hey.com>

parent  b7433ca1a4
commit  a904ea78ea
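How the new peak metrics are computed, in short: the diff below reconstructs every output token's emission time from each request's start_time, TTFT, and inter-token latencies (itl), drops those timestamps into one-second buckets, and reports the fullest bucket as the peak output throughput together with the number of requests in flight. The following is a minimal standalone sketch of that idea, not the vLLM code itself; the request tuples and their values are made-up illustrations.

```python
import numpy as np

# Stand-ins for the per-request fields the benchmark records:
# (start_time, ttft, inter-token latencies, total latency). Values are made up.
requests = [
    (0.0, 0.25, [0.05, 0.05, 0.05], 0.40),
    (0.1, 0.30, [0.04, 0.06], 0.50),
]

min_start = min(r[0] for r in requests)
max_end = max(r[0] + r[3] for r in requests)
n_buckets = int(np.ceil(max_end - min_start)) + 1

tokens_per_s = np.zeros(n_buckets)
concurrency = np.zeros(n_buckets)

for start, ttft, itl, latency in requests:
    # Reconstruct token emission times from TTFT and inter-token latencies.
    t = start + ttft
    token_times = [t]
    for gap in itl:
        t += gap
        token_times.append(t)
    for tt in token_times:
        tokens_per_s[int(tt - min_start)] += 1
    # The request counts toward every second between its start and its end.
    for sec in range(int(start - min_start),
                     int(start + latency - min_start) + 1):
        concurrency[sec] += 1

print("peak output tokens/s:", tokens_per_s.max())          # 7.0
print("peak concurrent requests:", int(concurrency.max()))  # 2
```

With the sample numbers, both requests finish within the first second, so the peak bucket holds all seven tokens with two requests in flight.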
@@ -89,6 +89,7 @@ class RequestFuncOutput:
     tpot: float = 0.0  # avg next-token latencies
     prompt_len: int = 0
     error: str = ""
+    start_time: float = 0.0


 async def async_request_openai_completions(
@@ -140,6 +141,7 @@ async def async_request_openai_completions(
 
     generated_text = ""
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
@@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
     generated_text = ""
     ttft = 0.0
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
@@ -396,6 +399,7 @@ async def async_request_openai_audio(
     generated_text = ""
     ttft = 0.0
     st = time.perf_counter()
+    output.start_time = st
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url,
@@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
 
     output = RequestFuncOutput()
     st = time.perf_counter()
+    output.start_time = st
     try:
         async with session.post(
             url=api_url,
@@ -18,9 +18,11 @@ On the client side, run:
 import argparse
 import asyncio
 import gc
+import importlib.util
 import json
 import os
 import random
+import shutil
 import time
 import warnings
 from collections.abc import AsyncGenerator, Iterable
@@ -46,6 +48,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
 
+TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
+                          and (shutil.which("gnuplot") is not None))
+
 
 class TaskType(Enum):
     GENERATION = "generation"
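In the hunk above, the terminal plot is enabled only when both the termplotlib package is importable and a gnuplot binary is on PATH (termplotlib renders through gnuplot). The same check can be run on its own to see whether the plot will appear; this is just a self-contained restatement of the two standard-library calls used above, not additional benchmark logic.

```python
import importlib.util
import shutil

# Mirrors TERM_PLOTLIB_AVAILABLE above: the optional plot needs both the
# Python package and the external gnuplot executable.
term_plot_available = (importlib.util.find_spec("termplotlib") is not None
                       and shutil.which("gnuplot") is not None)

if __name__ == "__main__":
    print("terminal plots enabled:", term_plot_available)
```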
@@ -80,18 +85,23 @@ class BenchmarkMetrics:
     median_e2el_ms: float
     std_e2el_ms: float
     percentiles_e2el_ms: list[tuple[float, float]]
+    # Max output tokens per second and concurrent requests at that peak
+    max_output_tokens_per_s: float
+    max_concurrent_requests: int


 @dataclass
 class EmbedBenchmarkMetrics:
     completed: int
     total_input: int
     request_throughput: float
-    total_token_throughput :float
+    total_token_throughput: float
     mean_e2el_ms: float
     std_e2el_ms: float
     median_e2el_ms: float
     percentiles_e2el_ms: float


 def _get_current_request_rate(
         ramp_up_strategy: Optional[Literal["linear", "exponential"]],
         ramp_up_start_rps: Optional[int],
@@ -150,8 +160,8 @@ async def get_request(
     assert burstiness > 0, (
         f"A positive burstiness factor is expected, but given {burstiness}.")
     # Convert to list to get length for ramp-up calculations
-    if isinstance(input_requests, Iterable) and not isinstance(
-            input_requests, list):
+    if isinstance(input_requests,
+                  Iterable) and not isinstance(input_requests, list):
         input_requests = list(input_requests)
 
     total_requests = len(input_requests)
@@ -161,12 +171,9 @@ async def get_request(
     request_rates = []
     delay_ts = []
     for request_index, request in enumerate(input_requests):
-        current_request_rate = _get_current_request_rate(ramp_up_strategy,
-                                                         ramp_up_start_rps,
-                                                         ramp_up_end_rps,
-                                                         request_index,
-                                                         total_requests,
-                                                         request_rate)
+        current_request_rate = _get_current_request_rate(
+            ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps,
+            request_index, total_requests, request_rate)
         request_rates.append(current_request_rate)
         if current_request_rate == float("inf"):
             delay_ts.append(0)
@@ -206,10 +213,8 @@ async def get_request(
 
 
 def calculate_metrics_for_embeddings(
-        outputs: list[RequestFuncOutput],
-        dur_s: float,
-        selected_percentiles: list[float]
-) -> EmbedBenchmarkMetrics:
+        outputs: list[RequestFuncOutput], dur_s: float,
+        selected_percentiles: list[float]) -> EmbedBenchmarkMetrics:
     """Calculate the metrics for the embedding requests.
 
     Args:
@@ -242,10 +247,8 @@ def calculate_metrics_for_embeddings(
         mean_e2el_ms=np.mean(e2els or 0) * 1000,
         std_e2el_ms=np.std(e2els or 0) * 1000,
         median_e2el_ms=np.median(e2els or 0) * 1000,
-        percentiles_e2el_ms=[
-            (p, np.percentile(e2els or 0, p) * 1000)
-            for p in selected_percentiles
-        ],
+        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
+                             for p in selected_percentiles],
     )
     return metrics
 
@@ -336,6 +339,67 @@ def calculate_metrics(
             "All requests failed. This is likely due to a misconfiguration "
             "on the benchmark arguments.",
             stacklevel=2)
+
+    # Calculate max output tokens per second metric
+    max_output_tokens_per_s = 0.0
+    max_concurrent_requests = 0
+
+    # Find the time range across all successful requests
+    successful_outputs = [output for output in outputs if output.success]
+    if successful_outputs:
+        min_start_time = min(output.start_time
+                             for output in successful_outputs)
+        max_end_time = max(output.start_time + output.latency
+                           for output in successful_outputs)
+
+        # Create second buckets (ceiling to ensure we capture all time)
+        duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
+        tokens_per_second = np.zeros(duration_seconds)
+        concurrent_requests_per_second = np.zeros(duration_seconds)
+
+        for i, output in enumerate(successful_outputs):
+            # Calculate token generation timestamp using
+            # start_time, ttft, and itl
+            token_times = [output.start_time + output.ttft]
+            current_time = token_times[0]
+            for itl_value in output.itl:
+                current_time += itl_value
+                token_times.append(current_time)
+
+            # Add tokens to second buckets
+            for token_time in token_times:
+                second_bucket = int(token_time - min_start_time)
+                if 0 <= second_bucket < duration_seconds:
+                    tokens_per_second[second_bucket] += 1
+
+            # Track concurrent requests for each second this request was active
+            request_start_second = int(output.start_time - min_start_time)
+            request_end_second = int((output.start_time + output.latency) -
+                                     min_start_time)
+
+            for second in range(request_start_second, request_end_second + 1):
+                concurrent_requests_per_second[second] += 1
+
+        # Find the maximum tokens per second and corresponding
+        # concurrent requests
+        if len(tokens_per_second) > 0:
+            max_output_tokens_per_s = float(np.max(tokens_per_second))
+            max_concurrent_requests = int(
+                np.max(concurrent_requests_per_second))
+
+        if TERM_PLOTLIB_AVAILABLE:
+            import termplotlib as tpl
+            fig = tpl.figure()
+            fig.plot(np.arange(len(tokens_per_second)),
+                     tokens_per_second,
+                     title="Output tokens per second")
+            fig.plot(np.arange(len(concurrent_requests_per_second)),
+                     concurrent_requests_per_second,
+                     title="Concurrent requests per second")
+            fig.show()
+        else:
+            print("tip: install termplotlib and gnuplot to plot the metrics")
+
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
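The plotting branch above only runs inside the benchmark when TERM_PLOTLIB_AVAILABLE is true, but the same per-second series can be drawn separately with the identical call shape. A small sketch with made-up data, assuming termplotlib and gnuplot are installed; the figure()/plot()/show() calls and the title keyword are simply copied from the code this commit adds.

```python
import numpy as np

try:
    import termplotlib as tpl
except ImportError:
    tpl = None

# Made-up per-second output token counts, standing in for tokens_per_second.
tokens_per_second = np.array([0, 120, 340, 510, 495, 380, 90], dtype=float)

if tpl is not None:
    fig = tpl.figure()
    # Same call shape as the benchmark code above.
    fig.plot(np.arange(len(tokens_per_second)),
             tokens_per_second,
             title="Output tokens per second")
    fig.show()
else:
    print("tip: install termplotlib and gnuplot to plot the metrics")
```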
@@ -365,6 +429,8 @@ def calculate_metrics(
         median_e2el_ms=np.median(e2els or 0) * 1000,
         percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
                              for p in selected_percentiles],
+        max_output_tokens_per_s=max_output_tokens_per_s,
+        max_concurrent_requests=max_concurrent_requests,
     )
 
     return metrics, actual_output_lens
@@ -396,11 +462,8 @@ async def benchmark(
     ramp_up_end_rps: Optional[int] = None,
     ready_check_timeout_sec: int = 600,
 ):
-    task_type = (
-        TaskType.EMBEDDING
-        if api_url.endswith("/v1/embeddings")
-        else TaskType.GENERATION
-    )
+    task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else
+                 TaskType.GENERATION)
     if endpoint_type in ASYNC_REQUEST_FUNCS:
         if task_type == TaskType.EMBEDDING:
             request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
@@ -435,14 +498,10 @@ async def benchmark(
         input_requests[0].multi_modal_data,
     )
 
-    assert (
-        test_mm_content is None
-        or isinstance(test_mm_content, dict)
-        or (
-            isinstance(test_mm_content, list)
-            and all(isinstance(item, dict) for item in test_mm_content)
-        )
-    ), "multi_modal_data must be a dict or list[dict]"
+    assert (test_mm_content is None or isinstance(test_mm_content, dict)
+            or (isinstance(test_mm_content, list)
+                and all(isinstance(item, dict) for item in test_mm_content))
+            ), "multi_modal_data must be a dict or list[dict]"
     test_input = RequestFuncInput(
         model=model_id,
         model_name=model_name,
@@ -488,13 +547,13 @@ async def benchmark(
                                         ignore_eos=ignore_eos,
                                         extra_headers=extra_headers,
                                         extra_body=extra_body)
-        profile_output = await request_func(
-            request_func_input=profile_input, session=session)
+        profile_output = await request_func(request_func_input=profile_input,
+                                            session=session)
         if profile_output.success:
             print("Profiler started")
 
-    distribution = ("Poisson process" if burstiness == 1.0
-                    else "Gamma distribution")
+    distribution = ("Poisson process"
+                    if burstiness == 1.0 else "Gamma distribution")
 
     if ramp_up_strategy is not None:
         print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@@ -562,18 +621,20 @@ async def benchmark(
             req_lora_module = next(lora_modules)
             req_model_id, req_model_name = req_lora_module, req_lora_module
 
-        request_func_input = RequestFuncInput(model=req_model_id,
-                                              model_name=req_model_name,
-                                              prompt=prompt,
-                                              api_url=api_url,
-                                              prompt_len=prompt_len,
-                                              output_len=output_len,
-                                              logprobs=logprobs,
-                                              multi_modal_content=mm_content,
-                                              ignore_eos=ignore_eos,
-                                              extra_headers=extra_headers,
-                                              extra_body=extra_body,
-                                              request_id=request_id,)
+        request_func_input = RequestFuncInput(
+            model=req_model_id,
+            model_name=req_model_name,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            logprobs=logprobs,
+            multi_modal_content=mm_content,
+            ignore_eos=ignore_eos,
+            extra_headers=extra_headers,
+            extra_body=extra_body,
+            request_id=request_id,
+        )
         tasks.append(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
@@ -615,19 +676,21 @@ async def benchmark(
                                     benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
     if isinstance(metrics, BenchmarkMetrics):
-        print("{:<40} {:<10}".format(
-            "Total generated tokens:", metrics.total_output))
+        print("{:<40} {:<10}".format("Total generated tokens:",
+                                     metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
     if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
     if isinstance(metrics, BenchmarkMetrics):
-        print(
-            "{:<40} {:<10.2f}".format(
-                "Output token throughput (tok/s):", metrics.output_throughput
-            )
-        )
+        print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                        metrics.output_throughput))
+        print("{:<40} {:<10.2f}".format(
+            "Peak output token throughput (tok/s):",
+            metrics.max_output_tokens_per_s))
+        print("{:<40} {:<10.2f}".format("Peak concurrent requests:",
+                                        metrics.max_concurrent_requests))
     print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
                                     metrics.total_token_throughput))
 
@@ -648,6 +711,8 @@ async def benchmark(
             "itls": [output.itl for output in outputs],
             "generated_texts": [output.generated_text for output in outputs],
             "errors": [output.error for output in outputs],
+            "max_output_tokens_per_s": metrics.max_output_tokens_per_s,
+            "max_concurrent_requests": metrics.max_concurrent_requests,
         }
     else:
         result = {
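The two peak values are also added to the generation-task result dict above, so they end up in any saved result file alongside the existing per-request fields. A minimal sketch of reading them back, assuming the run was saved via the script's --save-result option and that results.json is a hypothetical output path:

```python
import json

# Hypothetical path; substitute whatever file the benchmark run produced.
with open("results.json") as f:
    result = json.load(f)

# Keys added by this commit to the generation-task result dict.
print("Peak output tokens/s:", result["max_output_tokens_per_s"])
print("Peak concurrent requests:", result["max_concurrent_requests"])
```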
@@ -697,8 +762,8 @@ async def benchmark(
 
     if task_type == TaskType.GENERATION:
         process_one_metric("ttft", "TTFT", "Time to First Token")
-        process_one_metric(
-            "tpot", "TPOT", "Time per Output Token (excl. 1st token)")
+        process_one_metric("tpot", "TPOT",
+                           "Time per Output Token (excl. 1st token)")
         process_one_metric("itl", "ITL", "Inter-token Latency")
     process_one_metric("e2el", "E2EL", "End-to-end Latency")
 
@@ -714,8 +779,8 @@ async def benchmark(
             output_len=test_output_len,
             logprobs=logprobs,
         )
-        profile_output = await request_func(
-            request_func_input=profile_input, session=session)
+        profile_output = await request_func(request_func_input=profile_input,
+                                            session=session)
         if profile_output.success:
             print("Profiler stopped")
 
@@ -851,7 +916,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--tokenizer",
         type=str,
-        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+        help=
+        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
@@ -982,7 +1048,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="Specify the prefix of request id.",
     )
 
-
     sampling_group = parser.add_argument_group("sampling parameters")
     sampling_group.add_argument(
         "--top-p",
@@ -1047,8 +1112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="The ramp-up strategy. This would be used to "
         "ramp up the request rate from initial RPS to final "
        "RPS rate (specified by --ramp-up-start-rps and "
-        "--ramp-up-end-rps.) over the duration of the benchmark."
-    )
+        "--ramp-up-end-rps.) over the duration of the benchmark.")
     parser.add_argument(
         "--ramp-up-start-rps",
         type=int,
@@ -1087,13 +1151,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
             raise ValueError(
                 "When using ramp-up, do not specify --request-rate. "
                 "The request rate will be controlled by ramp-up parameters. "
-                "Please remove the --request-rate argument."
-            )
+                "Please remove the --request-rate argument.")
         if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
             raise ValueError(
                 "When using --ramp-up-strategy, both --ramp-up-start-rps and "
-                "--ramp-up-end-rps must be specified"
-            )
+                "--ramp-up-end-rps must be specified")
         if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
             raise ValueError("Ramp-up start and end RPS must be non-negative")
         if args.ramp_up_start_rps > args.ramp_up_end_rps:
@@ -1127,8 +1189,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
                 headers[kvstring[0].strip()] = kvstring[1].strip()
             else:
                 raise ValueError(
-                    "Invalid header format. Please use KEY=VALUE format."
-                )
+                    "Invalid header format. Please use KEY=VALUE format.")
 
     tokenizer = get_tokenizer(tokenizer_id,
                               tokenizer_mode=tokenizer_mode,
@@ -1215,8 +1276,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
                 result_json[kvstring[0].strip()] = kvstring[1].strip()
             else:
                 raise ValueError(
-                    "Invalid metadata format. Please use KEY=VALUE format."
-                )
+                    "Invalid metadata format. Please use KEY=VALUE format.")
 
     # Traffic
     result_json["request_rate"] = (args.request_rate if args.request_rate