[Misc][Benchmarking] Add variable request-rate ("ramp-up") to the benchmarking client. (#19423)

Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
d.transposed 2025-06-24 20:41:49 +02:00 committed by GitHub
parent a045b7e89a
commit c635c5f744
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 330 additions and 34 deletions

View File

@ -269,6 +269,21 @@ python3 vllm/benchmarks/benchmark_serving.py \
--num-prompts 10 --num-prompts 10
``` ```
### Running With Ramp-Up Request Rate
The benchmark tool also supports ramping up the request rate over the
duration of the benchmark run. This can be useful for stress testing the
server or finding the maximum throughput that it can handle, given some latency budget.
Two ramp-up strategies are supported:
- `linear`: Increases the request rate linearly from a start value to an end value.
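- `exponential`: Increases the request rate exponentially (by a constant factor per request) from a start value to an end value. The start value must be greater than 0.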
The following arguments can be used to control the ramp-up:
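- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). When set, both `--ramp-up-start-rps` and `--ramp-up-end-rps` are required, and `--request-rate` must not be specified.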
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
--- ---
## Example - Offline Throughput Benchmark ## Example - Offline Throughput Benchmark

View File

@ -33,7 +33,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Literal, Optional
import numpy as np import numpy as np
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
@ -107,14 +107,42 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]] percentiles_e2el_ms: list[tuple[float, float]]
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
ramp_up_end_rps: Optional[int],
request_index: int,
total_requests: int,
request_rate: float,
) -> float:
if (
ramp_up_strategy
and ramp_up_start_rps is not None
and ramp_up_end_rps is not None
):
progress = request_index / max(total_requests - 1, 1)
if ramp_up_strategy == "linear":
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
return ramp_up_start_rps + increase
elif ramp_up_strategy == "exponential":
ratio = ramp_up_end_rps / ramp_up_start_rps
return ramp_up_start_rps * (ratio**progress)
else:
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
return request_rate
async def get_request( async def get_request(
input_requests: list[SampleRequest], input_requests: list[SampleRequest],
request_rate: float, request_rate: float,
burstiness: float = 1.0, burstiness: float = 1.0,
) -> AsyncGenerator[SampleRequest, None]: ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
""" """
Asynchronously generates requests at a specified rate Asynchronously generates requests at a specified rate
with OPTIONAL burstiness. with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
Args: Args:
input_requests: input_requests:
@ -129,22 +157,44 @@ async def get_request(
A lower burstiness value (0 < burstiness < 1) results A lower burstiness value (0 < burstiness < 1) results
in more bursty requests, while a higher burstiness value in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests. (burstiness > 1) results in a more uniform arrival of requests.
ramp_up_strategy (optional):
The ramp-up strategy. Can be "linear" or "exponential".
If None, uses constant request rate (specified by request_rate).
ramp_up_start_rps (optional):
The starting request rate for ramp-up.
ramp_up_end_rps (optional):
The ending request rate for ramp-up.
""" """
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}." f"A positive burstiness factor is expected, but given {burstiness}."
) )
theta = 1.0 / (request_rate * burstiness) # Convert to list to get length for ramp-up calculations
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
input_requests = list(input_requests)
total_requests = len(input_requests)
request_index = 0
for request in input_requests: for request in input_requests:
yield request current_request_rate = _get_current_request_rate(
ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate,
)
if request_rate == float("inf"): yield request, current_request_rate
request_index += 1
if current_request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait. # If the request rate is infinity, then we don't need to wait.
continue continue
theta = 1.0 / (current_request_rate * burstiness)
# Sample the request interval from the gamma distribution. # Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution. # If burstiness is 1, it follows exponential distribution.
interval = np.random.gamma(shape=burstiness, scale=theta) interval = np.random.gamma(shape=burstiness, scale=theta)
@ -290,6 +340,9 @@ async def benchmark(
max_concurrency: Optional[int], max_concurrency: Optional[int],
lora_modules: Optional[Iterable[str]], lora_modules: Optional[Iterable[str]],
extra_body: Optional[dict], extra_body: Optional[dict],
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
@ -353,7 +406,15 @@ async def benchmark(
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}") if ramp_up_strategy is not None:
print(
f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
"the duration of the benchmark."
)
else:
print(f"Traffic request rate: {request_rate} RPS.")
print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}") print(f"Maximum request concurrency: {max_concurrency}")
@ -373,7 +434,34 @@ async def benchmark(
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
rps_change_events = []
last_int_rps = -1
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
last_int_rps = ramp_up_start_rps
rps_change_events.append(
{
"rps": last_int_rps,
"timestamp": datetime.now().isoformat(),
}
)
async for request, current_request_rate in get_request(
input_requests,
request_rate,
burstiness,
ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
):
if ramp_up_strategy is not None:
current_int_rps = int(current_request_rate)
if current_int_rps > last_int_rps:
timestamp = datetime.now().isoformat()
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
@ -397,11 +485,8 @@ async def benchmark(
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body, extra_body=extra_body,
) )
tasks.append( task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
asyncio.create_task( tasks.append(asyncio.create_task(task))
limited_request_func(request_func_input=request_func_input, pbar=pbar)
)
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
@ -477,6 +562,9 @@ async def benchmark(
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
if rps_change_events:
result["rps_change_events"] = rps_change_events
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
metric_attribute_name: str, metric_attribute_name: str,
@ -610,6 +698,26 @@ def main(args: argparse.Namespace):
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode tokenizer_mode = args.tokenizer_mode
# Validate ramp-up arguments
if args.ramp_up_strategy is not None:
if args.request_rate != float("inf"):
raise ValueError(
"When using ramp-up, do not specify --request-rate. "
"The request rate will be controlled by ramp-up parameters. "
"Please remove the --request-rate argument."
)
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
raise ValueError(
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"--ramp-up-end-rps must be specified"
)
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
raise ValueError("Ramp-up start and end RPS must be non-negative")
if args.ramp_up_start_rps > args.ramp_up_end_rps:
raise ValueError("Ramp-up start RPS must be less than end RPS")
if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
if args.base_url is not None: if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}" api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}" base_url = f"{args.base_url}"
@ -802,6 +910,9 @@ def main(args: argparse.Namespace):
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
) )
) )
@ -834,6 +945,11 @@ def main(args: argparse.Namespace):
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
if args.ramp_up_strategy is not None:
result_json["ramp_up_strategy"] = args.ramp_up_strategy
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
# Merge with benchmark result # Merge with benchmark result
result_json = {**result_json, **benchmark_result} result_json = {**result_json, **benchmark_result}
@ -859,6 +975,9 @@ def main(args: argparse.Namespace):
if args.max_concurrency is not None if args.max_concurrency is not None
else "" else ""
) )
if args.ramp_up_strategy is not None:
file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
@ -1225,6 +1344,31 @@ def create_argument_parser():
"script chooses a LoRA module at random.", "script chooses a LoRA module at random.",
) )
parser.add_argument(
"--ramp-up-strategy",
type=str,
default=None,
choices=["linear", "exponential"],
help="The ramp-up strategy. This would be used to "
"ramp up the request rate from initial RPS to final "
"RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). "
"over the duration of the benchmark.",
)
parser.add_argument(
"--ramp-up-start-rps",
type=int,
default=None,
help="The starting request rate for ramp-up (RPS). "
"Needs to be specified when --ramp-up-strategy is used.",
)
parser.add_argument(
"--ramp-up-end-rps",
type=int,
default=None,
help="The ending request rate for ramp-up (RPS). "
"Needs to be specified when --ramp-up-strategy is used.",
)
return parser return parser

View File

@ -26,7 +26,7 @@ import warnings
from collections.abc import AsyncGenerator, Iterable from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Literal, Optional
import numpy as np import numpy as np
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
@ -75,14 +75,39 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]] percentiles_e2el_ms: list[tuple[float, float]]
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
ramp_up_end_rps: Optional[int],
request_index: int,
total_requests: int,
request_rate: float,
) -> float:
if (ramp_up_strategy and ramp_up_start_rps is not None
and ramp_up_end_rps is not None):
progress = request_index / max(total_requests - 1, 1)
if ramp_up_strategy == "linear":
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
return ramp_up_start_rps + increase
elif ramp_up_strategy == "exponential":
ratio = ramp_up_end_rps / ramp_up_start_rps
return ramp_up_start_rps * (ratio**progress)
else:
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
return request_rate
async def get_request( async def get_request(
input_requests: list[SampleRequest], input_requests: list[SampleRequest],
request_rate: float, request_rate: float,
burstiness: float = 1.0, burstiness: float = 1.0,
) -> AsyncGenerator[SampleRequest, None]: ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
""" """
Asynchronously generates requests at a specified rate Asynchronously generates requests at a specified rate
with OPTIONAL burstiness. with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
Args: Args:
input_requests: input_requests:
@ -97,21 +122,42 @@ async def get_request(
A lower burstiness value (0 < burstiness < 1) results A lower burstiness value (0 < burstiness < 1) results
in more bursty requests, while a higher burstiness value in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests. (burstiness > 1) results in a more uniform arrival of requests.
ramp_up_strategy (optional):
The ramp-up strategy. Can be "linear" or "exponential".
If None, uses constant request rate (specified by request_rate).
ramp_up_start_rps (optional):
The starting request rate for ramp-up.
ramp_up_end_rps (optional):
The ending request rate for ramp-up.
""" """
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.") f"A positive burstiness factor is expected, but given {burstiness}.")
theta = 1.0 / (request_rate * burstiness) # Convert to list to get length for ramp-up calculations
if isinstance(input_requests, Iterable) and not isinstance(
input_requests, list):
input_requests = list(input_requests)
total_requests = len(input_requests)
request_index = 0
for request in input_requests: for request in input_requests:
yield request current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate)
if request_rate == float("inf"): yield request, current_request_rate
request_index += 1
if current_request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait. # If the request rate is infinity, then we don't need to wait.
continue continue
theta = 1.0 / (current_request_rate * burstiness)
# Sample the request interval from the gamma distribution. # Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution. # If burstiness is 1, it follows exponential distribution.
interval = np.random.gamma(shape=burstiness, scale=theta) interval = np.random.gamma(shape=burstiness, scale=theta)
@ -259,6 +305,9 @@ async def benchmark(
max_concurrency: Optional[int], max_concurrency: Optional[int],
lora_modules: Optional[Iterable[str]], lora_modules: Optional[Iterable[str]],
extra_body: Optional[dict], extra_body: Optional[dict],
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
): ):
if endpoint_type in ASYNC_REQUEST_FUNCS: if endpoint_type in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type] request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@ -316,12 +365,16 @@ async def benchmark(
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
if burstiness == 1.0: distribution = ("Poisson process" if burstiness == 1.0
distribution = "Poisson process" else "Gamma distribution")
else:
distribution = "Gamma distribution"
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
print(f"Will increase RPS from {ramp_up_start_rps} to "
f"{ramp_up_end_rps} RPS over the duration of the benchmark.")
else:
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}") print(f"Maximum request concurrency: {max_concurrency}")
@ -344,7 +397,29 @@ async def benchmark(
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
rps_change_events = []
last_int_rps = -1
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
last_int_rps = ramp_up_start_rps
rps_change_events.append({
"rps": last_int_rps,
"timestamp": datetime.now().isoformat(),
})
async for request, current_request_rate in get_request(
input_requests, request_rate, burstiness, ramp_up_strategy,
ramp_up_start_rps, ramp_up_end_rps):
if ramp_up_strategy is not None:
current_int_rps = int(current_request_rate)
if current_int_rps > last_int_rps:
timestamp = datetime.now().isoformat()
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
rps_change_events.append({
"rps": rps_val,
"timestamp": timestamp
})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
@ -435,6 +510,9 @@ async def benchmark(
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
if rps_change_events:
result["rps_change_events"] = rps_change_events
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
metric_attribute_name: str, metric_attribute_name: str,
@ -771,12 +849,60 @@ def add_cli_args(parser: argparse.ArgumentParser):
"launching the server. For each request, the " "launching the server. For each request, the "
"script chooses a LoRA module at random.") "script chooses a LoRA module at random.")
parser.add_argument(
"--ramp-up-strategy",
type=str,
default=None,
choices=["linear", "exponential"],
help="The ramp-up strategy. This would be used to "
"ramp up the request rate from initial RPS to final "
"RPS rate (specified by --ramp-up-start-rps and "
"--ramp-up-end-rps.) over the duration of the benchmark."
)
parser.add_argument(
"--ramp-up-start-rps",
type=int,
default=None,
help="The starting request rate for ramp-up (RPS). "
"Needs to be specified when --ramp-up-strategy is used.",
)
parser.add_argument(
"--ramp-up-end-rps",
type=int,
default=None,
help="The ending request rate for ramp-up (RPS). "
"Needs to be specified when --ramp-up-strategy is used.",
)
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
# Validate ramp-up arguments
if args.ramp_up_strategy is not None:
if args.request_rate != float("inf"):
raise ValueError(
"When using ramp-up, do not specify --request-rate. "
"The request rate will be controlled by ramp-up parameters. "
"Please remove the --request-rate argument."
)
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
raise ValueError(
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
"--ramp-up-end-rps must be specified"
)
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
raise ValueError("Ramp-up start and end RPS must be non-negative")
if args.ramp_up_start_rps > args.ramp_up_end_rps:
raise ValueError("Ramp-up start RPS must be less than end RPS")
if (args.ramp_up_strategy == "exponential"
and args.ramp_up_start_rps == 0):
raise ValueError(
"For exponential ramp-up, the start RPS cannot be 0.")
endpoint_type = args.endpoint_type
label = args.label label = args.label
model_id = args.model model_id = args.model
model_name = args.served_model_name model_name = args.served_model_name
@ -849,6 +975,9 @@ def main(args: argparse.Namespace):
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
)) ))
# Save config and results to json # Save config and results to json
@ -881,6 +1010,11 @@ def main(args: argparse.Namespace):
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
if args.ramp_up_strategy is not None:
result_json["ramp_up_strategy"] = args.ramp_up_strategy
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
# Merge with benchmark result # Merge with benchmark result
result_json = {**result_json, **benchmark_result} result_json = {**result_json, **benchmark_result}
@ -903,7 +1037,10 @@ def main(args: argparse.Namespace):
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}" max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "") if args.max_concurrency is not None else "")
label = label or args.endpoint_type label = label or endpoint_type
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename