[Benchmarks] add benchmark for embedding models (#23000)

commit 3ecbb14b81 (parent 7d67a9d9f9)
Jiangyun Zhu, 2025-08-26 14:57:08 +08:00, committed by GitHub
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
3 changed files with 274 additions and 107 deletions

File 1 of 3

@@ -73,7 +73,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """
-    prompt: Union[str, Any]
+    prompt: Union[str, list[str]]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[
@@ -409,6 +409,7 @@ class RandomDataset(BenchmarkDataset):
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
         output_len: int = DEFAULT_OUTPUT_LEN,
+        batchsize: int = 1,
         **kwargs,
     ) -> list[SampleRequest]:
@@ -439,6 +440,21 @@ class RandomDataset(BenchmarkDataset):
                     request_id=request_id_prefix + str(i),
                 )
             )
+        # only used for embeddings benchmark.
+        if batchsize > 1:
+            batch_requests = []
+            # Create batched requests
+            for i in range(0, num_requests, batchsize):
+                batch = requests[i : i + batchsize]
+                batch_requests.append(
+                    SampleRequest(
+                        prompt=[req.prompt for req in batch],
+                        prompt_len=sum(req.prompt_len for req in batch),
+                        expected_output_len=0,
+                        request_id=request_id_prefix + str(i // batchsize),
+                    )
+                )
+            requests = batch_requests
         return requests

     def get_prefix(
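
The batching path above chunks the flat request list into consecutive groups of batchsize and merges each group into a single SampleRequest whose prompt is a list of strings. A minimal standalone sketch of the same chunking pattern; the Req dataclass here is a hypothetical stand-in for SampleRequest:

    from dataclasses import dataclass

    @dataclass
    class Req:  # hypothetical stand-in for SampleRequest
        prompt: object
        prompt_len: int

    def batch_requests(requests: list, batchsize: int) -> list:
        # Chunk consecutive requests; the final batch may be smaller
        # when len(requests) is not a multiple of batchsize.
        batched = []
        for i in range(0, len(requests), batchsize):
            group = requests[i:i + batchsize]
            batched.append(Req(
                prompt=[r.prompt for r in group],
                prompt_len=sum(r.prompt_len for r in group),
            ))
        return batched

    reqs = [Req(prompt=f"p{i}", prompt_len=8) for i in range(5)]
    print([r.prompt for r in batch_requests(reqs, 2)])
    # [['p0', 'p1'], ['p2', 'p3'], ['p4']]

Note that the real code pins expected_output_len to 0, since embedding requests generate no tokens.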
@@ -475,8 +491,8 @@ class RandomDataset(BenchmarkDataset):
         input_high = math.ceil(real_input_len * (1 + range_ratio))
         output_low = math.floor(output_len * (1 - range_ratio))
         output_high = math.ceil(output_len * (1 + range_ratio))
         # Ensure the lower bound for output length is at least 1 to
         # prevent sampling 0 tokens.
         output_low = max(output_low, 1)

         if input_low > input_high:
@@ -506,7 +522,6 @@ class RandomDataset(BenchmarkDataset):
                                         size=num_requests)
         return input_lens, output_lens, offsets

-
     def generate_token_sequence(
         self,
         *,
@@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
                 "context length sampled from [input_len * (1 - range_ratio), "
                 "input_len * (1 + range_ratio)]."),
     )
+    random_group.add_argument(
+        "--random-batch-size",
+        type=int,
+        default=1,
+        help=("Batch size for random sampling. "
+              "Only used for embeddings benchmark."),
+    )

     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
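
Taken together with the new endpoint support in the second file, an embeddings run of the random dataset might look roughly like the command below. The model name is a placeholder, and the flag spellings are inferred from the argparse wiring in this commit (args.endpoint_type, args.num_prompts, args.random_input_len, args.random_batch_size), so treat the exact invocation as a sketch:

    vllm bench serve \
      --model <your_embedding_model> \
      --endpoint-type openai-embeddings \
      --endpoint /v1/embeddings \
      --dataset-name random \
      --num-prompts 128 \
      --random-input-len 64 \
      --random-batch-size 8

With --random-batch-size 8, each dispatched request carries 8 prompts, so the reported request throughput counts batches rather than individual prompts.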
@@ -1196,8 +1218,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         ),
     )
-
-
     hf_group = parser.add_argument_group("hf dataset options")
     hf_group.add_argument("--hf-subset",
                           type=str,
@@ -1348,22 +1368,24 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     else:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
-            "sharegpt":
-            lambda: ShareGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).sample(
-                tokenizer=tokenizer,
-                num_requests=args.num_prompts,
-                output_len=args.sharegpt_output_len,
-                request_id_prefix=args.request_id_prefix,
-            ),
-            "burstgpt":
-            lambda: BurstGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).
-            sample(tokenizer=tokenizer, num_requests=args.num_prompts,
-                   request_id_prefix=args.request_id_prefix,),
-            "random":
-            lambda: RandomDataset(random_seed=args.seed,
-                                  dataset_path=args.dataset_path).sample(
+            "sharegpt": lambda: ShareGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                output_len=args.sharegpt_output_len,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "burstgpt": lambda: BurstGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "random": lambda: RandomDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 prefix_len=args.random_prefix_len,
@@ -1371,6 +1393,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 output_len=args.random_output_len,
                 range_ratio=args.random_range_ratio,
                 request_id_prefix=args.request_id_prefix,
+                batchsize=args.random_batch_size,
             ),
         "random-mm":
         lambda: RandomMultiModalDataset(

File 2 of 3

@@ -69,8 +69,8 @@ async def async_request_openai_completions(
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

     payload = {
-        "model": request_func_input.model_name \
-        if request_func_input.model_name else request_func_input.model,
+        "model": request_func_input.model_name
+        if request_func_input.model_name else request_func_input.model,
         "prompt": request_func_input.prompt,
         "temperature": 0.0,
         "repetition_penalty": 1.0,
@@ -135,7 +135,7 @@ async def async_request_openai_completions(
                     # Decoding phase
                     else:
                         output.itl.append(timestamp -
                                           most_recent_timestamp)

                     most_recent_timestamp = timestamp
                     generated_text += text or ""
@@ -254,7 +254,7 @@ async def async_request_openai_chat_completions(
                         # Decoding phase
                         else:
                             output.itl.append(timestamp -
                                               most_recent_timestamp)

                         generated_text += content or ""
                     elif usage := data.get("usage"):
@@ -394,12 +394,61 @@ async def async_request_openai_audio(
     return output


+async def async_request_openai_embeddings(
+    request_func_input: RequestFuncInput,
+    session: aiohttp.ClientSession,
+    pbar: Optional[tqdm] = None,
+):
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        "embeddings"
+    ), "OpenAI Embeddings API URL must end with 'embeddings'."
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+    }
+
+    payload = {
+        "model": request_func_input.model,
+        "input": request_func_input.prompt,
+    }
+
+    output = RequestFuncOutput()
+    st = time.perf_counter()
+    try:
+        async with session.post(
+            url=api_url,
+            headers=headers,
+            json=payload
+        ) as response:
+            if response.status == 200:
+                output.latency = time.perf_counter() - st
+                data = await response.json()
+                output.success = True
+                output.generated_text = ""
+                output.prompt_len = data.get(
+                    "usage", {}).get(
+                    "prompt_tokens", 0)
+            else:
+                output.success = False
+                output.error = response.reason or ""
+    except Exception as e:
+        output.success = False
+        output.error = str(e)
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 # TODO: Add more request functions for different API protocols.
 ASYNC_REQUEST_FUNCS = {
     "vllm": async_request_openai_completions,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,
     "openai-audio": async_request_openai_audio,
+    "openai-embeddings": async_request_openai_embeddings,
 }

 OPENAI_COMPATIBLE_BACKENDS = [
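
On the wire, the new request function issues a plain OpenAI-style embeddings call: a POST to a URL ending in "embeddings" with a JSON body of {"model": ..., "input": ...}, where input may be a single string or, in the batched case above, a list of strings. A minimal sketch of the same call, assuming a vLLM server with an embedding model on localhost:8000:

    import asyncio

    import aiohttp

    async def main():
        payload = {
            "model": "my-embedding-model",  # placeholder model name
            "input": ["hello world", "another sentence"],  # batched input
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://localhost:8000/v1/embeddings", json=payload
            ) as resp:
                data = await resp.json()
                # usage.prompt_tokens is what the benchmark stores as
                # output.prompt_len for the embedding metrics.
                print(resp.status,
                      data.get("usage", {}).get("prompt_tokens"))

    asyncio.run(main())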

File 3 of 3

@@ -4,7 +4,7 @@ r"""Benchmark online serving throughput.
 On the server side, run one of the following commands
 to launch the vLLM OpenAI API server:
     vllm serve <your_model> <engine arguments>

 On the client side, run:
     vllm bench serve \
@@ -26,6 +26,7 @@ import warnings
 from collections.abc import AsyncGenerator, Iterable
 from dataclasses import dataclass
 from datetime import datetime
+from enum import Enum
 from typing import Any, Literal, Optional

 import aiohttp
@@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000


+class TaskType(Enum):
+    GENERATION = "generation"
+    EMBEDDING = "embedding"
+
+
 @dataclass
 class BenchmarkMetrics:
     completed: int
@@ -75,6 +81,16 @@ class BenchmarkMetrics:
     std_e2el_ms: float
     percentiles_e2el_ms: list[tuple[float, float]]

+@dataclass
+class EmbedBenchmarkMetrics:
+    completed: int
+    total_input: int
+    request_throughput: float
+    total_token_throughput: float
+    mean_e2el_ms: float
+    std_e2el_ms: float
+    median_e2el_ms: float
+    percentiles_e2el_ms: list[tuple[float, float]]

 def _get_current_request_rate(
         ramp_up_strategy: Optional[Literal["linear", "exponential"]],
@@ -146,11 +162,11 @@ async def get_request(
     delay_ts = []
     for request_index, request in enumerate(input_requests):
         current_request_rate = _get_current_request_rate(ramp_up_strategy,
                                                          ramp_up_start_rps,
                                                          ramp_up_end_rps,
                                                          request_index,
                                                          total_requests,
                                                          request_rate)
         request_rates.append(current_request_rate)
         if current_request_rate == float("inf"):
             delay_ts.append(0)
@@ -160,7 +176,7 @@ async def get_request(
             # Sample the request interval from the gamma distribution.
             # If burstiness is 1, it follows exponential distribution.
             delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))

     # Calculate the cumulative delay time from the first sent out requests.
     for i in range(1, len(delay_ts)):
         delay_ts[i] += delay_ts[i - 1]
@@ -170,11 +186,11 @@ async def get_request(
         # logic would re-scale delay time to ensure the final delay_ts
         # align with target_total_delay_s.
         #
         # NOTE: If we simply accumulate the random delta values
         # from the gamma distribution, their sum would have 1-2% gap
         # from target_total_delay_s. The purpose of the following logic is to
         # close the gap for stabilizing the throughput data
         # from different random seeds.
         target_total_delay_s = total_requests / request_rate
         normalize_factor = target_total_delay_s / delay_ts[-1]
         delay_ts = [delay * normalize_factor for delay in delay_ts]
@@ -189,6 +205,51 @@ async def get_request(
         yield request, request_rates[request_index]


+def calculate_metrics_for_embeddings(
+        outputs: list[RequestFuncOutput],
+        dur_s: float,
+        selected_percentiles: list[float]
+) -> EmbedBenchmarkMetrics:
+    """Calculate the metrics for the embedding requests.
+
+    Args:
+        outputs: The outputs of the requests.
+        dur_s: The duration of the benchmark.
+        selected_percentiles: The percentiles to select.
+
+    Returns:
+        The calculated benchmark metrics.
+    """
+    total_input = 0
+    completed = 0
+    e2els: list[float] = []
+    for i in range(len(outputs)):
+        if outputs[i].success:
+            e2els.append(outputs[i].latency)
+            completed += 1
+            total_input += outputs[i].prompt_len
+
+    if completed == 0:
+        warnings.warn(
+            "All requests failed. This is likely due to a misconfiguration "
+            "on the benchmark arguments.",
+            stacklevel=2)
+
+    metrics = EmbedBenchmarkMetrics(
+        completed=completed,
+        total_input=total_input,
+        request_throughput=completed / dur_s,
+        total_token_throughput=total_input / dur_s,
+        mean_e2el_ms=np.mean(e2els or 0) * 1000,
+        std_e2el_ms=np.std(e2els or 0) * 1000,
+        median_e2el_ms=np.median(e2els or 0) * 1000,
+        percentiles_e2el_ms=[
+            (p, np.percentile(e2els or 0, p) * 1000)
+            for p in selected_percentiles
+        ],
+    )
+    return metrics
+
+
 def calculate_metrics(
     input_requests: list[SampleRequest],
     outputs: list[RequestFuncOutput],
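
One detail worth noting in the reductions above: the "e2els or 0" guard makes numpy reduce over the scalar 0 when no request succeeded, so the metrics degrade to 0.0 instead of warning on an empty sequence. A small sketch of the aggregation with made-up latencies:

    import numpy as np

    # Per-request end-to-end latencies in seconds; reported in milliseconds.
    e2els = [0.120, 0.095, 0.210, 0.150]
    selected_percentiles = [50.0, 99.0]

    mean_e2el_ms = np.mean(e2els or 0) * 1000
    median_e2el_ms = np.median(e2els or 0) * 1000
    percentiles_e2el_ms = [(p, np.percentile(e2els or 0, p) * 1000)
                           for p in selected_percentiles]
    print(mean_e2el_ms, median_e2el_ms, percentiles_e2el_ms)

    # With e2els == [], `e2els or 0` evaluates to 0 and every reduction
    # returns 0.0 rather than nan.
    print(np.mean([] or 0) * 1000)  # 0.0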
@@ -334,8 +395,16 @@ async def benchmark(
     ramp_up_end_rps: Optional[int] = None,
     ready_check_timeout_sec: int = 600,
 ):
+    task_type = (
+        TaskType.EMBEDDING
+        if api_url.endswith("/v1/embeddings")
+        else TaskType.GENERATION
+    )
+
     if endpoint_type in ASYNC_REQUEST_FUNCS:
-        request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
+        if task_type == TaskType.EMBEDDING:
+            request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
+        else:
+            request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
     else:
         raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
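
Task detection hinges only on the URL suffix: any api_url ending in /v1/embeddings is routed to the openai-embeddings request function, whatever endpoint_type was passed. A tiny sketch of that dispatch rule in isolation:

    from enum import Enum

    class TaskType(Enum):
        GENERATION = "generation"
        EMBEDDING = "embedding"

    def detect_task(api_url: str) -> TaskType:
        # Mirrors the check above: only the URL suffix decides the task.
        if api_url.endswith("/v1/embeddings"):
            return TaskType.EMBEDDING
        return TaskType.GENERATION

    print(detect_task("http://localhost:8000/v1/embeddings").value)   # embedding
    print(detect_task("http://localhost:8000/v1/completions").value)  # generation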
@@ -421,8 +490,8 @@ async def benchmark(
         if profile_output.success:
             print("Profiler started")

     distribution = ("Poisson process" if burstiness == 1.0
                     else "Gamma distribution")

     if ramp_up_strategy is not None:
         print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
@@ -449,7 +518,7 @@ async def benchmark(
                                       session=session,
                                       pbar=pbar)
         async with semaphore:
             return await request_func(request_func_input=request_func_input,
                                       session=session,
                                       pbar=pbar)
@@ -513,14 +582,22 @@ async def benchmark(
     benchmark_duration = time.perf_counter() - benchmark_start_time

-    metrics, actual_output_lens = calculate_metrics(
-        input_requests=input_requests,
-        outputs=outputs,
-        dur_s=benchmark_duration,
-        tokenizer=tokenizer,
-        selected_percentiles=selected_percentiles,
-        goodput_config_dict=goodput_config_dict,
-    )
+    if task_type == TaskType.GENERATION:
+        metrics, actual_output_lens = calculate_metrics(
+            input_requests=input_requests,
+            outputs=outputs,
+            dur_s=benchmark_duration,
+            tokenizer=tokenizer,
+            selected_percentiles=selected_percentiles,
+            goodput_config_dict=goodput_config_dict,
+        )
+    else:
+        metrics = calculate_metrics_for_embeddings(
+            outputs=outputs,
+            dur_s=benchmark_duration,
+            selected_percentiles=selected_percentiles,
+        )
+        actual_output_lens = 0

     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
@@ -529,39 +606,55 @@ async def benchmark(
                                          max_concurrency))
     if request_rate != float('inf'):
         print("{:<40} {:<10.2f}".format("Request rate configured (RPS):",
-                                        request_rate ))
+                                        request_rate))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                     benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
-    print("{:<40} {:<10}".format("Total generated tokens:",
-                                 metrics.total_output))
+    if isinstance(metrics, BenchmarkMetrics):
+        print("{:<40} {:<10}".format(
+            "Total generated tokens:", metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
     if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
-    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
-                                    metrics.output_throughput))
+    if isinstance(metrics, BenchmarkMetrics):
+        print(
+            "{:<40} {:<10.2f}".format(
+                "Output token throughput (tok/s):", metrics.output_throughput
+            )
+        )
     print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
                                     metrics.total_token_throughput))

-    result = {
-        "duration": benchmark_duration,
-        "completed": metrics.completed,
-        "total_input_tokens": metrics.total_input,
-        "total_output_tokens": metrics.total_output,
-        "request_throughput": metrics.request_throughput,
-        "request_goodput":
-        metrics.request_goodput if goodput_config_dict else None,
-        "output_throughput": metrics.output_throughput,
-        "total_token_throughput": metrics.total_token_throughput,
-        "input_lens": [output.prompt_len for output in outputs],
-        "output_lens": actual_output_lens,
-        "ttfts": [output.ttft for output in outputs],
-        "itls": [output.itl for output in outputs],
-        "generated_texts": [output.generated_text for output in outputs],
-        "errors": [output.error for output in outputs],
-    }
+    if isinstance(metrics, BenchmarkMetrics):
+        result = {
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
+            "total_input_tokens": metrics.total_input,
+            "total_output_tokens": metrics.total_output,
+            "request_throughput": metrics.request_throughput,
+            "request_goodput":
+            metrics.request_goodput if goodput_config_dict else None,
+            "output_throughput": metrics.output_throughput,
+            "total_token_throughput": metrics.total_token_throughput,
+            "input_lens": [output.prompt_len for output in outputs],
+            "output_lens": actual_output_lens,
+            "ttfts": [output.ttft for output in outputs],
+            "itls": [output.itl for output in outputs],
+            "generated_texts": [output.generated_text for output in outputs],
+            "errors": [output.error for output in outputs],
+        }
+    else:
+        result = {
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
+            "total_input_tokens": metrics.total_input,
+            "request_throughput": metrics.request_throughput,
+            "total_token_throughput": metrics.total_token_throughput,
+            "input_lens": [output.prompt_len for output in outputs],
+            "errors": [output.error for output in outputs],
+        }

     if rps_change_events:
         result["rps_change_events"] = rps_change_events
@@ -598,10 +691,11 @@ async def benchmark(
                                      value))
             result[f"p{p_word}_{metric_attribute_name}_ms"] = value

-    process_one_metric("ttft", "TTFT", "Time to First Token")
-    process_one_metric("tpot", "TPOT",
-                       "Time per Output Token (excl. 1st token)")
-    process_one_metric("itl", "ITL", "Inter-token Latency")
+    if task_type == TaskType.GENERATION:
+        process_one_metric("ttft", "TTFT", "Time to First Token")
+        process_one_metric(
+            "tpot", "TPOT", "Time per Output Token (excl. 1st token)")
+        process_one_metric("itl", "ITL", "Inter-token Latency")
     process_one_metric("e2el", "E2EL", "End-to-end Latency")

     print("=" * 50)
@@ -732,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
         "initiated, this argument will control how many are actually allowed "
         "to execute at a time. This means that when used in combination, the "
         "actual request rate may be lower than specified with --request-rate, "
-        "if the server is not processing requests fast enough to keep up.")
+        "if the server is not processing requests fast enough to keep up.",
+    )

     parser.add_argument(
         "--model",
@@ -743,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--tokenizer",
         type=str,
-        help=
-        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
@@ -968,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
 def main(args: argparse.Namespace) -> dict[str, Any]:
     return asyncio.run(main_async(args))

+
 async def main_async(args: argparse.Namespace) -> dict[str, Any]:
     print(args)
     random.seed(args.seed)
@@ -1046,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
     gc.freeze()

     benchmark_result = await benchmark(
         endpoint_type=args.endpoint_type,
         api_url=api_url,
         base_url=base_url,
         model_id=model_id,
         model_name=model_name,
         tokenizer=tokenizer,
         input_requests=input_requests,
         logprobs=args.logprobs,
         request_rate=args.request_rate,
         burstiness=args.burstiness,
         disable_tqdm=args.disable_tqdm,
         profile=args.profile,
         selected_percentile_metrics=args.percentile_metrics.split(","),
         selected_percentiles=[
             float(p) for p in args.metric_percentiles.split(",")
         ],
         ignore_eos=args.ignore_eos,
         goodput_config_dict=goodput_config_dict,
         max_concurrency=args.max_concurrency,
         lora_modules=args.lora_modules,
         extra_body=sampling_params,
         ramp_up_strategy=args.ramp_up_strategy,
         ramp_up_start_rps=args.ramp_up_start_rps,
         ramp_up_end_rps=args.ramp_up_end_rps,
         ready_check_timeout_sec=args.ready_check_timeout_sec,
     )

     # Save config and results to json
     result_json: dict[str, Any] = {}
@@ -1098,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
     # Traffic
     result_json["request_rate"] = (args.request_rate if args.request_rate
                                    < float("inf") else "inf")
     result_json["burstiness"] = args.burstiness
     result_json["max_concurrency"] = args.max_concurrency
@@ -1132,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
                            if args.max_concurrency is not None else "")
     label = label or endpoint_type
     if args.ramp_up_strategy is not None:
         file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
     else:
         file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
     if args.result_filename:
@@ -1149,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
             json.dump(result_json, outfile)
         save_to_pytorch_benchmark_format(args, result_json, file_name)

     return result_json
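
For an embedding run, the saved JSON therefore carries only the reduced key set built above (no output_lens, ttfts, itls, or generated_texts). A sketch of reading one result file back; the path is a placeholder:

    import json

    # Path is a placeholder for a file produced by this benchmark.
    with open("embedding-benchmark-results.json") as f:
        result = json.load(f)

    for key in ("duration", "completed", "total_input_tokens",
                "request_throughput", "total_token_throughput"):
        print(f"{key}: {result[key]}")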