[Misc][Benchmarking] Add variable request-rate ("ramp-up") to the benchmarking client. (#19423)
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Roger Wang <hey@rogerw.me>
commit c635c5f744
parent a045b7e89a
@@ -269,6 +269,21 @@ python3 vllm/benchmarks/benchmark_serving.py \
   --num-prompts 10
 ```
 
+### Running With Ramp-Up Request Rate
+
+The benchmark tool also supports ramping up the request rate over the
+duration of the benchmark run. This can be useful for stress testing the
+server or for finding the maximum throughput it can handle within a given latency budget.
+
+Two ramp-up strategies are supported:
+- `linear`: Increases the request rate linearly from a start value to an end value.
+- `exponential`: Increases the request rate exponentially.
+
+The following arguments can be used to control the ramp-up:
+- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`).
+- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
+- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
+
+---
 
 ## Example - Offline Throughput Benchmark
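For reference, the two schedules (implemented by the `_get_current_request_rate` helper in the diff below) can be written out explicitly. With N total requests and progress p = i / max(N - 1, 1) for request index i:

```latex
r_{\mathrm{linear}}(p) = r_{\mathrm{start}} + \left(r_{\mathrm{end}} - r_{\mathrm{start}}\right) p
\qquad
r_{\mathrm{exp}}(p) = r_{\mathrm{start}} \left(\frac{r_{\mathrm{end}}}{r_{\mathrm{start}}}\right)^{p}
```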
@@ -33,7 +33,7 @@ import warnings
 from collections.abc import AsyncGenerator, Iterable
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import numpy as np
 from tqdm.asyncio import tqdm
@@ -107,14 +107,42 @@ class BenchmarkMetrics:
     percentiles_e2el_ms: list[tuple[float, float]]
 
 
+def _get_current_request_rate(
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]],
+    ramp_up_start_rps: Optional[int],
+    ramp_up_end_rps: Optional[int],
+    request_index: int,
+    total_requests: int,
+    request_rate: float,
+) -> float:
+    if (
+        ramp_up_strategy
+        and ramp_up_start_rps is not None
+        and ramp_up_end_rps is not None
+    ):
+        progress = request_index / max(total_requests - 1, 1)
+        if ramp_up_strategy == "linear":
+            increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
+            return ramp_up_start_rps + increase
+        elif ramp_up_strategy == "exponential":
+            ratio = ramp_up_end_rps / ramp_up_start_rps
+            return ramp_up_start_rps * (ratio**progress)
+        else:
+            raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
+    return request_rate
+
+
 async def get_request(
     input_requests: list[SampleRequest],
     request_rate: float,
     burstiness: float = 1.0,
-) -> AsyncGenerator[SampleRequest, None]:
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
+) -> AsyncGenerator[tuple[SampleRequest, float], None]:
     """
     Asynchronously generates requests at a specified rate
-    with OPTIONAL burstiness.
+    with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
 
     Args:
         input_requests:
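As a quick illustration of the helper above, here is a standalone restatement of its schedule (the `ramp_rate` name is hypothetical, not part of the commit) evaluated over a small run:

```python
# Standalone sketch of the ramp schedule from _get_current_request_rate.
def ramp_rate(strategy: str, start: float, end: float, i: int, n: int) -> float:
    # progress is 0.0 for the first request and 1.0 for the last
    progress = i / max(n - 1, 1)
    if strategy == "linear":
        return start + (end - start) * progress
    # exponential: the rate is multiplied by a constant factor per step
    return start * (end / start) ** progress

n = 5
print([round(ramp_rate("linear", 1, 5, i, n), 2) for i in range(n)])
# [1.0, 2.0, 3.0, 4.0, 5.0]
print([round(ramp_rate("exponential", 1, 16, i, n), 2) for i in range(n)])
# [1.0, 2.0, 4.0, 8.0, 16.0]
```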
@@ -129,22 +157,44 @@ async def get_request(
             A lower burstiness value (0 < burstiness < 1) results
             in more bursty requests, while a higher burstiness value
             (burstiness > 1) results in a more uniform arrival of requests.
+        ramp_up_strategy (optional):
+            The ramp-up strategy. Can be "linear" or "exponential".
+            If None, uses a constant request rate (specified by request_rate).
+        ramp_up_start_rps (optional):
+            The starting request rate for ramp-up.
+        ramp_up_end_rps (optional):
+            The ending request rate for ramp-up.
     """
-    input_requests: Iterable[SampleRequest] = iter(input_requests)
-
-    # Calculate scale parameter theta to maintain the desired request_rate.
     assert burstiness > 0, (
         f"A positive burstiness factor is expected, but given {burstiness}."
     )
-    theta = 1.0 / (request_rate * burstiness)
+    # Convert to list to get length for ramp-up calculations
+    if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
+        input_requests = list(input_requests)
+
+    total_requests = len(input_requests)
+    request_index = 0
 
     for request in input_requests:
-        yield request
+        current_request_rate = _get_current_request_rate(
+            ramp_up_strategy,
+            ramp_up_start_rps,
+            ramp_up_end_rps,
+            request_index,
+            total_requests,
+            request_rate,
+        )
 
-        if request_rate == float("inf"):
+        yield request, current_request_rate
+
+        request_index += 1
+
+        if current_request_rate == float("inf"):
             # If the request rate is infinity, then we don't need to wait.
             continue
 
+        theta = 1.0 / (current_request_rate * burstiness)
+
         # Sample the request interval from the gamma distribution.
         # If burstiness is 1, it follows exponential distribution.
         interval = np.random.gamma(shape=burstiness, scale=theta)
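The gamma sampling above keeps the mean inter-arrival time at `1 / current_request_rate` regardless of burstiness, since a gamma variable with `shape=burstiness` and `scale=theta` has mean `burstiness * theta`. A small numerical check (illustrative sketch, not from the commit):

```python
# Burstiness reshapes inter-arrival times while preserving the mean rate.
import numpy as np

rng = np.random.default_rng(0)
rate = 4.0  # target requests per second
for burstiness in (0.5, 1.0, 2.0):
    theta = 1.0 / (rate * burstiness)
    intervals = rng.gamma(shape=burstiness, scale=theta, size=100_000)
    # Mean stays ~0.25 s; lower burstiness -> higher variance (burstier).
    print(f"burstiness={burstiness}: mean={intervals.mean():.3f}s, "
          f"std={intervals.std():.3f}s")
```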
@@ -290,6 +340,9 @@ async def benchmark(
     max_concurrency: Optional[int],
     lora_modules: Optional[Iterable[str]],
     extra_body: Optional[dict],
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -353,7 +406,15 @@ async def benchmark(
 
     distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
 
-    print(f"Traffic request rate: {request_rate}")
+    if ramp_up_strategy is not None:
+        print(
+            f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
+            f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
+            "the duration of the benchmark."
+        )
+    else:
+        print(f"Traffic request rate: {request_rate} RPS.")
 
     print(f"Burstiness factor: {burstiness} ({distribution})")
     print(f"Maximum request concurrency: {max_concurrency}")
@@ -373,7 +434,34 @@ async def benchmark(
 
     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate, burstiness):
+
+    rps_change_events = []
+    last_int_rps = -1
+    if ramp_up_strategy is not None and ramp_up_start_rps is not None:
+        last_int_rps = ramp_up_start_rps
+        rps_change_events.append(
+            {
+                "rps": last_int_rps,
+                "timestamp": datetime.now().isoformat(),
+            }
+        )
+
+    async for request, current_request_rate in get_request(
+        input_requests,
+        request_rate,
+        burstiness,
+        ramp_up_strategy,
+        ramp_up_start_rps,
+        ramp_up_end_rps,
+    ):
+        if ramp_up_strategy is not None:
+            current_int_rps = int(current_request_rate)
+            if current_int_rps > last_int_rps:
+                timestamp = datetime.now().isoformat()
+                for rps_val in range(last_int_rps + 1, current_int_rps + 1):
+                    rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
+                last_int_rps = current_int_rps
+
         prompt, prompt_len, output_len, mm_content = (
             request.prompt,
             request.prompt_len,
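Each entry appended above pairs an integer RPS level with an ISO-format timestamp, recording when that level was first crossed. An illustration of the resulting list for a linear 1-to-3 ramp (the timestamps are invented):

```python
# Illustrative only: real values depend on when each level is crossed.
rps_change_events = [
    {"rps": 1, "timestamp": "2025-06-11T10:00:00.000000"},  # recorded up front
    {"rps": 2, "timestamp": "2025-06-11T10:00:41.127311"},
    {"rps": 3, "timestamp": "2025-06-11T10:01:22.903554"},
]
```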
@@ -397,11 +485,8 @@ async def benchmark(
             ignore_eos=ignore_eos,
             extra_body=extra_body,
         )
-        tasks.append(
-            asyncio.create_task(
-                limited_request_func(request_func_input=request_func_input, pbar=pbar)
-            )
-        )
+        task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
+        tasks.append(asyncio.create_task(task))
     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
 
     if profile:
@@ -477,6 +562,9 @@ async def benchmark(
         "errors": [output.error for output in outputs],
     }
 
+    if rps_change_events:
+        result["rps_change_events"] = rps_change_events
+
     def process_one_metric(
         # E.g., "ttft"
         metric_attribute_name: str,
@@ -610,6 +698,26 @@ def main(args: argparse.Namespace):
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer_mode = args.tokenizer_mode
 
+    # Validate ramp-up arguments
+    if args.ramp_up_strategy is not None:
+        if args.request_rate != float("inf"):
+            raise ValueError(
+                "When using ramp-up, do not specify --request-rate. "
+                "The request rate will be controlled by ramp-up parameters. "
+                "Please remove the --request-rate argument."
+            )
+        if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
+            raise ValueError(
+                "When using --ramp-up-strategy, both --ramp-up-start-rps and "
+                "--ramp-up-end-rps must be specified"
+            )
+        if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
+            raise ValueError("Ramp-up start and end RPS must be non-negative")
+        if args.ramp_up_start_rps > args.ramp_up_end_rps:
+            raise ValueError("Ramp-up start RPS must not exceed end RPS")
+        if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
+            raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
+
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
         base_url = f"{args.base_url}"
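The contract enforced by this validation block, restated as a standalone sketch (the `validate` helper is hypothetical, not part of the commit):

```python
# Sketch of the ramp-up argument contract enforced above.
def validate(strategy, start, end, request_rate=float("inf")):
    if strategy is None:
        return  # constant-rate mode: no ramp-up constraints apply
    if request_rate != float("inf"):
        raise ValueError("do not combine --request-rate with ramp-up")
    if start is None or end is None:
        raise ValueError("both --ramp-up-start-rps and --ramp-up-end-rps are required")
    if start < 0 or end < 0:
        raise ValueError("start and end RPS must be non-negative")
    if start > end:
        raise ValueError("start RPS must not exceed end RPS")
    if strategy == "exponential" and start == 0:
        raise ValueError("exponential ramp-up cannot start at 0 RPS")

validate("linear", 2, 20)       # ok
validate("exponential", 1, 16)  # ok
# validate("exponential", 0, 16)  # would raise: zero start with exponential
```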
@@ -802,6 +910,9 @@ def main(args: argparse.Namespace):
             max_concurrency=args.max_concurrency,
             lora_modules=args.lora_modules,
             extra_body=sampling_params,
+            ramp_up_strategy=args.ramp_up_strategy,
+            ramp_up_start_rps=args.ramp_up_start_rps,
+            ramp_up_end_rps=args.ramp_up_end_rps,
         )
     )
@@ -834,6 +945,11 @@ def main(args: argparse.Namespace):
         result_json["burstiness"] = args.burstiness
         result_json["max_concurrency"] = args.max_concurrency
 
+        if args.ramp_up_strategy is not None:
+            result_json["ramp_up_strategy"] = args.ramp_up_strategy
+            result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
+            result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
+
         # Merge with benchmark result
         result_json = {**result_json, **benchmark_result}
@@ -859,6 +975,9 @@ def main(args: argparse.Namespace):
             if args.max_concurrency is not None
             else ""
         )
-        file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
+        if args.ramp_up_strategy is not None:
+            file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
+        else:
+            file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
         if args.result_filename:
             file_name = args.result_filename
@@ -1225,6 +1344,31 @@ def create_argument_parser():
         "script chooses a LoRA module at random.",
     )
 
+    parser.add_argument(
+        "--ramp-up-strategy",
+        type=str,
+        default=None,
+        choices=["linear", "exponential"],
+        help="The ramp-up strategy, used to ramp up the request rate "
+        "from an initial RPS to a final RPS (specified by "
+        "--ramp-up-start-rps and --ramp-up-end-rps) over the "
+        "duration of the benchmark.",
+    )
+    parser.add_argument(
+        "--ramp-up-start-rps",
+        type=int,
+        default=None,
+        help="The starting request rate for ramp-up (RPS). "
+        "Needs to be specified when --ramp-up-strategy is used.",
+    )
+    parser.add_argument(
+        "--ramp-up-end-rps",
+        type=int,
+        default=None,
+        help="The ending request rate for ramp-up (RPS). "
+        "Needs to be specified when --ramp-up-strategy is used.",
+    )
+
     return parser
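A hypothetical smoke test of the new flags (the import path and the `--model` flag are assumptions, not shown in this diff):

```python
# Hypothetical: assumes the script is importable as benchmark_serving and
# that --model is its only other required flag.
from benchmark_serving import create_argument_parser  # hypothetical import path

args = create_argument_parser().parse_args([
    "--model", "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    "--ramp-up-strategy", "linear",
    "--ramp-up-start-rps", "2",
    "--ramp-up-end-rps", "20",
])
assert args.ramp_up_strategy == "linear"
assert args.ramp_up_start_rps == 2 and args.ramp_up_end_rps == 20
# --request-rate is left at its default (inf), so the ramp-up validation in
# main() passes; supplying it explicitly would raise ValueError there.
```

The remaining hunks apply the same change to a second, differently formatted copy of the benchmark script in the repository, so only the formatting differs below.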
@@ -26,7 +26,7 @@ import warnings
 from collections.abc import AsyncGenerator, Iterable
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import numpy as np
 from tqdm.asyncio import tqdm
@@ -75,14 +75,39 @@ class BenchmarkMetrics:
     percentiles_e2el_ms: list[tuple[float, float]]
 
 
+def _get_current_request_rate(
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]],
+    ramp_up_start_rps: Optional[int],
+    ramp_up_end_rps: Optional[int],
+    request_index: int,
+    total_requests: int,
+    request_rate: float,
+) -> float:
+    if (ramp_up_strategy and ramp_up_start_rps is not None
+            and ramp_up_end_rps is not None):
+        progress = request_index / max(total_requests - 1, 1)
+        if ramp_up_strategy == "linear":
+            increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
+            return ramp_up_start_rps + increase
+        elif ramp_up_strategy == "exponential":
+            ratio = ramp_up_end_rps / ramp_up_start_rps
+            return ramp_up_start_rps * (ratio**progress)
+        else:
+            raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
+    return request_rate
+
+
 async def get_request(
     input_requests: list[SampleRequest],
     request_rate: float,
     burstiness: float = 1.0,
-) -> AsyncGenerator[SampleRequest, None]:
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
+) -> AsyncGenerator[tuple[SampleRequest, float], None]:
     """
     Asynchronously generates requests at a specified rate
-    with OPTIONAL burstiness.
+    with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
 
     Args:
         input_requests:
@@ -97,21 +122,42 @@ async def get_request(
             A lower burstiness value (0 < burstiness < 1) results
             in more bursty requests, while a higher burstiness value
             (burstiness > 1) results in a more uniform arrival of requests.
+        ramp_up_strategy (optional):
+            The ramp-up strategy. Can be "linear" or "exponential".
+            If None, uses a constant request rate (specified by request_rate).
+        ramp_up_start_rps (optional):
+            The starting request rate for ramp-up.
+        ramp_up_end_rps (optional):
+            The ending request rate for ramp-up.
     """
-    input_requests: Iterable[SampleRequest] = iter(input_requests)
-
-    # Calculate scale parameter theta to maintain the desired request_rate.
     assert burstiness > 0, (
         f"A positive burstiness factor is expected, but given {burstiness}.")
-    theta = 1.0 / (request_rate * burstiness)
+    # Convert to list to get length for ramp-up calculations
+    if isinstance(input_requests, Iterable) and not isinstance(
+            input_requests, list):
+        input_requests = list(input_requests)
+
+    total_requests = len(input_requests)
+    request_index = 0
 
     for request in input_requests:
-        yield request
+        current_request_rate = _get_current_request_rate(ramp_up_strategy,
+                                                         ramp_up_start_rps,
+                                                         ramp_up_end_rps,
+                                                         request_index,
+                                                         total_requests,
+                                                         request_rate)
 
-        if request_rate == float("inf"):
+        yield request, current_request_rate
+
+        request_index += 1
+
+        if current_request_rate == float("inf"):
             # If the request rate is infinity, then we don't need to wait.
             continue
 
+        theta = 1.0 / (current_request_rate * burstiness)
+
         # Sample the request interval from the gamma distribution.
         # If burstiness is 1, it follows exponential distribution.
         interval = np.random.gamma(shape=burstiness, scale=theta)
@@ -259,6 +305,9 @@ async def benchmark(
     max_concurrency: Optional[int],
     lora_modules: Optional[Iterable[str]],
     extra_body: Optional[dict],
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
 ):
     if endpoint_type in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -316,12 +365,16 @@ async def benchmark(
     if profile_output.success:
         print("Profiler started")
 
-    if burstiness == 1.0:
-        distribution = "Poisson process"
-    else:
-        distribution = "Gamma distribution"
+    distribution = ("Poisson process" if burstiness == 1.0
+                    else "Gamma distribution")
 
-    print(f"Traffic request rate: {request_rate}")
+    if ramp_up_strategy is not None:
+        print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
+        print(f"Will increase RPS from {ramp_up_start_rps} to "
+              f"{ramp_up_end_rps} RPS over the duration of the benchmark.")
+    else:
+        print(f"Traffic request rate: {request_rate}")
 
     print(f"Burstiness factor: {burstiness} ({distribution})")
     print(f"Maximum request concurrency: {max_concurrency}")
@@ -344,7 +397,29 @@ async def benchmark(
 
     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate, burstiness):
+
+    rps_change_events = []
+    last_int_rps = -1
+    if ramp_up_strategy is not None and ramp_up_start_rps is not None:
+        last_int_rps = ramp_up_start_rps
+        rps_change_events.append({
+            "rps": last_int_rps,
+            "timestamp": datetime.now().isoformat(),
+        })
+
+    async for request, current_request_rate in get_request(
+            input_requests, request_rate, burstiness, ramp_up_strategy,
+            ramp_up_start_rps, ramp_up_end_rps):
+        if ramp_up_strategy is not None:
+            current_int_rps = int(current_request_rate)
+            if current_int_rps > last_int_rps:
+                timestamp = datetime.now().isoformat()
+                for rps_val in range(last_int_rps + 1, current_int_rps + 1):
+                    rps_change_events.append({
+                        "rps": rps_val,
+                        "timestamp": timestamp
+                    })
+                last_int_rps = current_int_rps
         prompt, prompt_len, output_len, mm_content = (
             request.prompt,
             request.prompt_len,
@@ -435,6 +510,9 @@ async def benchmark(
         "errors": [output.error for output in outputs],
     }
 
+    if rps_change_events:
+        result["rps_change_events"] = rps_change_events
+
     def process_one_metric(
         # E.g., "ttft"
         metric_attribute_name: str,
@@ -771,12 +849,60 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         "launching the server. For each request, the "
                         "script chooses a LoRA module at random.")
 
+    parser.add_argument(
+        "--ramp-up-strategy",
+        type=str,
+        default=None,
+        choices=["linear", "exponential"],
+        help="The ramp-up strategy, used to ramp up the request rate "
+        "from an initial RPS to a final RPS (specified by "
+        "--ramp-up-start-rps and --ramp-up-end-rps) "
+        "over the duration of the benchmark.",
+    )
+    parser.add_argument(
+        "--ramp-up-start-rps",
+        type=int,
+        default=None,
+        help="The starting request rate for ramp-up (RPS). "
+        "Needs to be specified when --ramp-up-strategy is used.",
+    )
+    parser.add_argument(
+        "--ramp-up-end-rps",
+        type=int,
+        default=None,
+        help="The ending request rate for ramp-up (RPS). "
+        "Needs to be specified when --ramp-up-strategy is used.",
+    )
+
 
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
     np.random.seed(args.seed)
 
+    # Validate ramp-up arguments
+    if args.ramp_up_strategy is not None:
+        if args.request_rate != float("inf"):
+            raise ValueError(
+                "When using ramp-up, do not specify --request-rate. "
+                "The request rate will be controlled by ramp-up parameters. "
+                "Please remove the --request-rate argument.")
+        if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
+            raise ValueError(
+                "When using --ramp-up-strategy, both --ramp-up-start-rps and "
+                "--ramp-up-end-rps must be specified")
+        if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
+            raise ValueError("Ramp-up start and end RPS must be non-negative")
+        if args.ramp_up_start_rps > args.ramp_up_end_rps:
+            raise ValueError("Ramp-up start RPS must not exceed end RPS")
+        if (args.ramp_up_strategy == "exponential"
+                and args.ramp_up_start_rps == 0):
+            raise ValueError(
+                "For exponential ramp-up, the start RPS cannot be 0.")
+
     endpoint_type = args.endpoint_type
     label = args.label
     model_id = args.model
     model_name = args.served_model_name
@@ -849,6 +975,9 @@ def main(args: argparse.Namespace):
             max_concurrency=args.max_concurrency,
             lora_modules=args.lora_modules,
             extra_body=sampling_params,
+            ramp_up_strategy=args.ramp_up_strategy,
+            ramp_up_start_rps=args.ramp_up_start_rps,
+            ramp_up_end_rps=args.ramp_up_end_rps,
         ))
 
     # Save config and results to json
@@ -881,6 +1010,11 @@ def main(args: argparse.Namespace):
         result_json["burstiness"] = args.burstiness
         result_json["max_concurrency"] = args.max_concurrency
 
+        if args.ramp_up_strategy is not None:
+            result_json["ramp_up_strategy"] = args.ramp_up_strategy
+            result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
+            result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
+
         # Merge with benchmark result
         result_json = {**result_json, **benchmark_result}
@@ -903,8 +1037,11 @@ def main(args: argparse.Namespace):
         base_model_id = model_id.split("/")[-1]
         max_concurrency_str = (f"-concurrency{args.max_concurrency}"
                                if args.max_concurrency is not None else "")
-        label = label or args.endpoint_type
-        file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  #noqa
+        label = label or endpoint_type
+        if args.ramp_up_strategy is not None:
+            file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
+        else:
+            file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
         if args.result_filename:
             file_name = args.result_filename
         if args.result_dir:
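For clarity, a sketch of the result-file name the ramp-up branch above produces (all values illustrative, including the assumed `%Y%m%d-%H%M%S` timestamp format):

```python
# All values illustrative; mirrors the f-string in the ramp-up branch above.
label = "openai-chat"
strategy, start_rps, end_rps = "linear", 2, 20
max_concurrency_str = "-concurrency32"
base_model_id = "Llama-3.1-8B-Instruct"
current_dt = "20250611-100000"  # assumed strftime("%Y%m%d-%H%M%S") format
file_name = (f"{label}-ramp-up-{strategy}-{start_rps}qps-{end_rps}qps"
             f"{max_concurrency_str}-{base_model_id}-{current_dt}.json")
print(file_name)
# openai-chat-ramp-up-linear-2qps-20qps-concurrency32-Llama-3.1-8B-Instruct-20250611-100000.json
```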