diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index e586337367b1c..93519b5ba1523 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -73,7 +73,7 @@ class SampleRequest: Represents a single inference request for benchmarking. """ - prompt: Union[str, Any] + prompt: Union[str, list[str]] prompt_len: int expected_output_len: int multi_modal_data: Optional[ @@ -409,6 +409,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: @@ -439,6 +440,21 @@ class RandomDataset(BenchmarkDataset): request_id=request_id_prefix + str(i), ) ) + # only used for embeddings benchmark. + if batchsize > 1: + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + batch_requests.append( + SampleRequest( + prompt=[req.prompt for req in batch], + prompt_len=sum(req.prompt_len for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + requests = batch_requests return requests def get_prefix( @@ -475,8 +491,8 @@ class RandomDataset(BenchmarkDataset): input_high = math.ceil(real_input_len * (1 + range_ratio)) output_low = math.floor(output_len * (1 - range_ratio)) output_high = math.ceil(output_len * (1 + range_ratio)) - # Ensure the lower bound for output length is at least 1 to - # prevent sampling 0 tokens. + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. output_low = max(output_low, 1) if input_low > input_high: @@ -506,7 +522,6 @@ class RandomDataset(BenchmarkDataset): size=num_requests) return input_lens, output_lens, offsets - def generate_token_sequence( self, *, @@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "context length sampled from [input_len * (1 - range_ratio), " "input_len * (1 + range_ratio)]."), ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. " + "Only used for embeddings benchmark."), + ) # random multimodal dataset options random_mm_group = parser.add_argument_group( @@ -1196,8 +1218,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ), ) - - hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", type=str, @@ -1348,22 +1368,24 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - request_id_prefix=args.request_id_prefix, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). 
- sample(tokenizer=tokenizer, num_requests=args.num_prompts, - request_id_prefix=args.request_id_prefix,), - "random": - lambda: RandomDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, @@ -1371,6 +1393,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: output_len=args.random_output_len, range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, ), "random-mm": lambda: RandomMultiModalDataset( diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 76beded4d5189..6bb2a497119e9 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -69,8 +69,8 @@ async def async_request_openai_completions( ), "OpenAI Completions API URL must end with 'completions' or 'profile'." payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -135,7 +135,7 @@ async def async_request_openai_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" @@ -254,7 +254,7 @@ async def async_request_openai_chat_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): @@ -394,12 +394,61 @@ async def async_request_openai_audio( return output +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +): + api_url = request_func_input.api_url + assert api_url.endswith( + "embeddings" + ), "OpenAI Embeddings API URL must end with 'embeddings'." + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + payload = { + "model": request_func_input.model, + "input": request_func_input.prompt, + } + + output = RequestFuncOutput() + st = time.perf_counter() + try: + async with session.post( + url=api_url, + headers=headers, + json=payload + ) as response: + if response.status == 200: + output.latency = time.perf_counter() - st + data = await response.json() + output.success = True + output.generated_text = "" + output.prompt_len = data.get( + "usage", {}).get( + "prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + # TODO: Add more request functions for different API protocols. 
ASYNC_REQUEST_FUNCS = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, } OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 79f2c475cbe5d..abb838316cd31 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -4,7 +4,7 @@ r"""Benchmark online serving throughput. On the server side, run one of the following commands to launch the vLLM OpenAI API server: - vllm serve + vllm serve On the client side, run: vllm bench serve \ @@ -26,6 +26,7 @@ import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime +from enum import Enum from typing import Any, Literal, Optional import aiohttp @@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +class TaskType(Enum): + GENERATION = "generation" + EMBEDDING = "embedding" + + @dataclass class BenchmarkMetrics: completed: int @@ -75,6 +81,16 @@ class BenchmarkMetrics: std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] +@dataclass +class EmbedBenchmarkMetrics: + completed: int + total_input: int + request_throughput: float + total_token_throughput :float + mean_e2el_ms: float + std_e2el_ms: float + median_e2el_ms: float + percentiles_e2el_ms: float def _get_current_request_rate( ramp_up_strategy: Optional[Literal["linear", "exponential"]], @@ -146,11 +162,11 @@ async def get_request( delay_ts = [] for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -160,7 +176,7 @@ async def get_request( # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) - + # Calculate the cumulative delay time from the first sent out requests. for i in range(1, len(delay_ts)): delay_ts[i] += delay_ts[i - 1] @@ -170,11 +186,11 @@ async def get_request( # logic would re-scale delay time to ensure the final delay_ts # align with target_total_delay_s. # - # NOTE: If we simply accumulate the random delta values - # from the gamma distribution, their sum would have 1-2% gap + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data - # from different random seeds. + # close the gap for stablizing the throughput data + # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] delay_ts = [delay * normalize_factor for delay in delay_ts] @@ -189,6 +205,51 @@ async def get_request( yield request, request_rates[request_index] +def calculate_metrics_for_embeddings( + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. 
+ + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles + ], + ) + return metrics + + def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -334,8 +395,16 @@ async def benchmark( ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): + task_type = ( + TaskType.EMBEDDING + if api_url.endswith("/v1/embeddings") + else TaskType.GENERATION + ) if endpoint_type in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + if task_type == TaskType.EMBEDDING: + request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] + else: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] else: raise ValueError(f"Unknown endpoint_type: {endpoint_type}") @@ -421,8 +490,8 @@ async def benchmark( if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") @@ -449,7 +518,7 @@ async def benchmark( session=session, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, + return await request_func(request_func_input=request_func_input, session=session, pbar=pbar) @@ -513,14 +582,22 @@ async def benchmark( benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) @@ -529,39 +606,55 @@ async def benchmark( max_concurrency)) if request_rate != float('inf'): print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate )) + request_rate)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - 
print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) + if isinstance(metrics, BenchmarkMetrics): + print("{:<40} {:<10}".format( + "Total generated tokens:", metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) + if isinstance(metrics, BenchmarkMetrics): + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } + if isinstance(metrics, BenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } if rps_change_events: result["rps_change_events"] = rps_change_events @@ -598,10 +691,11 @@ async def benchmark( value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric( + "tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") print("=" * 50) @@ -732,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "initiated, this argument will control how many are actually allowed " "to execute at a time. 
This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -743,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -968,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) + async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) @@ -1046,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: gc.freeze() benchmark_result = await benchmark( - endpoint_type=args.endpoint_type, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ready_check_timeout_sec=args.ready_check_timeout_sec, - ) + endpoint_type=args.endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) # Save config and results to json result_json: dict[str, Any] = {} @@ -1098,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Traffic result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + < float("inf") else "inf") result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1132,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.max_concurrency is not None else "") label = label or endpoint_type if args.ramp_up_strategy is not None: - file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + file_name = 
f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: @@ -1149,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) - return result_json \ No newline at end of file + return result_json