From e56f44d9ec6f8f07f0d2c7936eea9bb2c0212bf2 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 27 May 2025 19:59:48 -0400 Subject: [PATCH] Support datasets in `vllm bench serve` and sync with benchmark_[serving,datasets].py (#18566) --- vllm/benchmarks/datasets.py | 185 ++++++++++- vllm/benchmarks/endpoint_request_func.py | 226 +++++++++++++- vllm/benchmarks/serve.py | 380 +++++++++++++++++------ 3 files changed, 691 insertions(+), 100 deletions(-) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 74a9b2b03391..712e83528f12 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -62,6 +62,7 @@ class SampleRequest: class BenchmarkDataset(ABC): DEFAULT_SEED = 0 + IS_MULTIMODAL = False def __init__( self, @@ -316,13 +317,15 @@ class RandomDataset(BenchmarkDataset): ) vocab_size = tokenizer.vocab_size + num_special_tokens = tokenizer.num_special_tokens_to_add() + real_input_len = input_len - num_special_tokens prefix_token_ids = (np.random.randint( 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(input_len * (1 - range_ratio)) - input_high = int(input_len * (1 + range_ratio)) + input_low = int(real_input_len * (1 - range_ratio)) + input_high = int(real_input_len * (1 + range_ratio)) output_low = int(output_len * (1 - range_ratio)) output_high = int(output_len * (1 + range_ratio)) @@ -345,6 +348,17 @@ class RandomDataset(BenchmarkDataset): vocab_size).tolist() token_sequence = prefix_token_ids + inner_seq prompt = tokenizer.decode(token_sequence) + # After decoding the prompt we have to encode and decode it again. + # This is done because in some cases N consecutive tokens + # give a string tokenized into != N number of tokens. + # For example for GPT2Tokenizer: + # [6880, 6881] -> ['Ġcalls', 'here'] -> + # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + # To avoid uncontrolled change of the prompt length, + # the encoded sequence is truncated before being decode again. + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:input_lens[i]] + prompt = tokenizer.decode(re_encoded_sequence) total_input_len = prefix_len + int(input_lens[i]) requests.append( SampleRequest( @@ -637,6 +651,7 @@ class ConversationDataset(HuggingFaceDataset): SUPPORTED_DATASET_PATHS = { 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' } + IS_MULTIMODAL = True def sample(self, tokenizer: PreTrainedTokenizerBase, @@ -701,6 +716,7 @@ class VisionArenaDataset(HuggingFaceDataset): "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"] } + IS_MULTIMODAL = True def sample( self, @@ -784,6 +800,64 @@ class InstructCoderDataset(HuggingFaceDataset): return sampled_requests +# ----------------------------------------------------------------------------- +# MT-Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. + https://huggingface.co/datasets/philschmid/mt-bench + + We create a single turn dataset for MT-Bench. 
+ This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "philschmid/mt-bench", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0] + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + # ----------------------------------------------------------------------------- # AIMO Dataset Implementation # ----------------------------------------------------------------------------- @@ -858,18 +932,18 @@ def _format_zeta_prompt( sample: dict, original_start_marker: str = "<|editable_region_start|>") -> dict: """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. - - This function formats examples from the NEP dataset - into prompts and expected outputs. It could be + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be further extended to support more NEP datasets. - + Args: - sample: The dataset sample containing events, + sample: The dataset sample containing events, inputs, and outputs. - original_start_marker: The marker indicating the - start of the editable region. Defaults to + original_start_marker: The marker indicating the + start of the editable region. Defaults to "<|editable_region_start|>". - + Returns: A dictionary with the formatted prompts and expected outputs. """ @@ -919,3 +993,94 @@ class NextEditPredictionDataset(HuggingFaceDataset): break self.maybe_oversample_requests(samples, num_requests) return samples + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset class for processing a ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... 
| + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. + TRANSCRIPTION_PREAMBLE = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>") + skip_long_audios: bool = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: + try: + import librosa + except ImportError as e: + raise ImportError( + "librosa is required for ASRDataset. Please install it " + "using `pip install librosa`.") from e + + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + prompt = ASRDataset.TRANSCRIPTION_PREAMBLE + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests = [] + skipped = 0 + for item in self.data: + if len(sampled_requests) >= num_requests: + break + audio = item["audio"] + y, sr = audio["array"], audio["sampling_rate"] + duration_s = librosa.get_duration(y=y, sr=sr) + # Whisper max supported duration + if self.skip_long_audios and duration_s > 30: + skipped += 1 + continue + + mm_content = {"audio": (y, sr)} + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + if skipped: + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/endpoint_request_func.py b/vllm/benchmarks/endpoint_request_func.py index 32767a896070..a28630d50f26 100644 --- a/vllm/benchmarks/endpoint_request_func.py +++ b/vllm/benchmarks/endpoint_request_func.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """The request function for API endpoints.""" +import io import json import os import sys @@ -24,11 +25,11 @@ class RequestFuncInput: output_len: int model: str model_name: Optional[str] = None - best_of: int = 1 logprobs: Optional[int] = None extra_body: Optional[dict] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False + language: Optional[str] = None @dataclass @@ -71,7 +72,7 @@ async def async_request_openai_completions( if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, - "best_of": request_func_input.best_of, + "repetition_penalty": 1.0, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, "stream": True, @@ -154,7 +155,226 @@ async def async_request_openai_completions( return output +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 
'chat/completions'.") + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "messages": [ + { + "role": "user", + "content": content + }, + ], + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + + api_url = request_func_input.api_url + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' ") + "or `translations`." 
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "language": + "en", + # Flattened due to multipart/form-data + "stream_include_usage": + True, + "stream_continuous_usage_stats": + True, + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: + form = aiohttp.FormData() + form.add_field("file", f, content_type="audio/wav") + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + # TODO: Add more request functions for different API protocols. 
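 # Sketch of how this registry is consumed by the benchmark loop in
 # vllm/benchmarks/serve.py (variable names here are illustrative):
 #
 #     request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
 #     output = await request_func(request_func_input=request_func_input)
 #
 # Each entry maps an --endpoint-type value to an async coroutine that takes a
 # RequestFuncInput and returns a RequestFuncOutput carrying TTFT, ITL,
 # latency, and the generated text.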
ASYNC_REQUEST_FUNCS = { - "openai-comp": async_request_openai_completions, + "vllm": async_request_openai_completions, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, } + +OPENAI_COMPATIBLE_BACKENDS = [ + k for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, + async_request_openai_chat_completions) +] diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index dc0ec3219486..040815e879f0 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -7,7 +7,7 @@ to launch the vLLM OpenAI API server: On the client side, run: vllm bench serve \ - --endpoint-type \ + --endpoint-type \ --label \ --model \ --dataset-name \ @@ -22,7 +22,7 @@ import os import random import time import warnings -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime from typing import Any, Optional @@ -31,7 +31,14 @@ import numpy as np from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset, + ConversationDataset, HuggingFaceDataset, + InstructCoderDataset, MTBenchDataset, + NextEditPredictionDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, RequestFuncOutput) from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, @@ -71,53 +78,18 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] -def sample_random_requests( - prefix_len: int, - input_len: int, - output_len: int, - num_prompts: int, - range_ratio: float, - tokenizer: PreTrainedTokenizerBase, -) -> list[tuple[str, int, int]]: - prefix_token_ids = np.random.randint(0, - tokenizer.vocab_size, - size=prefix_len).tolist() - - input_lens = np.random.randint( - int(input_len * range_ratio), - input_len + 1, - size=num_prompts, - ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, - ) - offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) - input_requests = [] - for i in range(num_prompts): - prompt = tokenizer.decode(prefix_token_ids + - [(offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i])]) - - input_requests.append((prompt, int(prefix_len + input_lens[i]), - int(output_lens[i]), None)) - - return input_requests - - async def get_request( - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[tuple[str, int, int], None]: +) -> AsyncGenerator[SampleRequest, None]: """ Asynchronously generates requests at a specified rate with OPTIONAL burstiness. Args: input_requests: - A list of input requests, each represented as a tuple. + A list of input requests, each represented as a SampleRequest. request_rate: The rate at which requests are generated (requests/s). burstiness (optional): @@ -129,7 +101,7 @@ async def get_request( in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. """ - input_requests = iter(input_requests) + input_requests: Iterable[SampleRequest] = iter(input_requests) # Calculate scale parameter theta to maintain the desired request_rate. 
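     # For a gamma distribution with shape k = burstiness and scale theta, the
     # mean inter-arrival time is k * theta, so keeping the mean arrival rate
     # at `request_rate` requires theta = 1 / (request_rate * burstiness).
     # With burstiness == 1 the intervals are exponential, i.e. Poisson
     # arrivals, as described in the docstring above.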
assert burstiness > 0, ( @@ -151,7 +123,7 @@ async def get_request( def calculate_metrics( - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, @@ -184,7 +156,7 @@ def calculate_metrics( if outputs[i].success: output_len = outputs[i].output_tokens - if output_len is None: + if not output_len: # We use the tokenizer to count the number of output tokens # for some serving backends instead of looking at # len(outputs[i].itl) since multiple output tokens may be @@ -194,7 +166,7 @@ def calculate_metrics( tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids) actual_output_lens.append(output_len) - total_input += input_requests[i][1] + total_input += input_requests[i].prompt_len tpot = 0 if output_len > 1: latency_minus_ttft = outputs[i].latency - outputs[i].ttft @@ -277,19 +249,19 @@ async def benchmark( model_id: str, model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: list[tuple[str, int, int]], + input_requests: list[SampleRequest], logprobs: Optional[int], - best_of: int, request_rate: float, burstiness: float, disable_tqdm: bool, profile: bool, selected_percentile_metrics: list[str], - selected_percentiles: list[str], + selected_percentiles: list[float], ignore_eos: bool, goodput_config_dict: dict[str, float], max_concurrency: Optional[int], - lora_modules: Optional[list[str]], + lora_modules: Optional[Iterable[str]], + extra_body: Optional[dict], ): if endpoint_type in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[endpoint_type] @@ -298,11 +270,13 @@ async def benchmark( print("Starting initial single prompt test run...") test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0]) - if endpoint_type != "openai-chat" and test_mm_content is not None: - # multi-modal benchmark is only available on OpenAI Chat endpoint. 
- raise ValueError("Multi-modal content is only supported on " - "'openai-chat' endpoint_type.") + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) + + assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -311,9 +285,9 @@ async def benchmark( prompt_len=test_prompt_len, output_len=test_output_len, logprobs=logprobs, - best_of=best_of, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_body=extra_body, ) test_output = await request_func(request_func_input=test_input) @@ -338,9 +312,9 @@ async def benchmark( prompt_len=test_prompt_len, output_len=test_output_len, logprobs=logprobs, - best_of=best_of, multi_modal_content=test_mm_content, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + extra_body=extra_body) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: print("Profiler started") @@ -374,7 +348,12 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) @@ -387,9 +366,9 @@ async def benchmark( prompt_len=prompt_len, output_len=output_len, logprobs=logprobs, - best_of=best_of, multi_modal_content=mm_content, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + extra_body=extra_body) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -405,7 +384,6 @@ async def benchmark( prompt_len=test_prompt_len, output_len=test_output_len, logprobs=logprobs, - best_of=best_of, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -567,7 +545,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--endpoint-type", type=str, - default="openai-comp", + default="openai", choices=list(ASYNC_REQUEST_FUNCS.keys()), ) parser.add_argument( @@ -596,9 +574,16 @@ def add_cli_args(parser: argparse.ArgumentParser): "--dataset-name", type=str, default="random", - choices=["random"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. 
" + "Or the huggingface dataset ID if using HF dataset.", + ) parser.add_argument( "--max-concurrency", type=int, @@ -624,13 +609,6 @@ def add_cli_args(parser: argparse.ArgumentParser): help= "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) - parser.add_argument( - "--best-of", - type=int, - default=1, - help="Generates `best_of` sequences per prompt and " - "returns the best one.", - ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( "--num-prompts", @@ -691,6 +669,17 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Specify to save benchmark results to a json file", ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) parser.add_argument( "--metadata", metavar="KEY=VALUE", @@ -733,6 +722,7 @@ def add_cli_args(parser: argparse.ArgumentParser): default="99", help="Comma-separated list of percentiles for selected metrics. " "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\"." "Use \"--percentile-metrics\" to select metrics.", ) parser.add_argument( @@ -745,7 +735,41 @@ def add_cli_args(parser: argparse.ArgumentParser): "separated by spaces. Allowed request level metric names are " "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) random_group = parser.add_argument_group("random dataset options") random_group.add_argument( @@ -765,9 +789,11 @@ def add_cli_args(parser: argparse.ArgumentParser): random_group.add_argument( "--random-range-ratio", type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. 
Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", ) random_group.add_argument( "--random-prefix-len", @@ -778,6 +804,54 @@ def add_cli_args(parser: argparse.ArgumentParser): " request is [random-prefix-len, " " random-prefix-len + random-prefix-len * random-range-ratio).") + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).", + ) + parser.add_argument( '--tokenizer-mode', type=str, @@ -826,27 +900,142 @@ def main(args: argparse.Namespace): tokenizer = get_tokenizer(tokenizer_id, tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) - # TODO: This should be refactored to use the benchmark_dataset.py - # in later PRs. + if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.") - elif args.dataset_name == "random": - input_requests = sample_random_requests( - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - num_prompts=args.num_prompts, - range_ratio=args.random_range_ratio, + + if args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. 
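+        # return_prompt_formatted=False leaves the prompt unformatted so the
+        # chat backend can apply the chat template server-side; other
+        # backends get return_prompt_formatted=True, i.e. the prompt is
+        # rendered with the tokenizer's chat template on the client, which is
+        # why a chat template is asserted in that branch below.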
+ if args.backend == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ConversationDataset + args.hf_split = "train" + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_class = AIMODataset + args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and endpoint_type not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + ).sample( + num_requests=args.num_prompts, tokenizer=tokenizer, + output_len=args.hf_output_len, ) else: - raise ValueError(f"Unknown dataset: {args.dataset_name}") + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). 
+ sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + ), + } + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err goodput_config_dict = check_goodput_args(args) + # Collect the sampling parameters. + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature, + }.items() if v is not None + } + + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError("Sampling parameters are only supported by " + "openai-compatible backends.") + + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + # Avoid GC processing "static" data - reduce pause times. gc.collect() gc.freeze() @@ -861,7 +1050,6 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, input_requests=input_requests, logprobs=args.logprobs, - best_of=args.best_of, request_rate=args.request_rate, burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, @@ -874,10 +1062,11 @@ def main(args: argparse.Namespace): goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, + extra_body=sampling_params, )) # Save config and results to json - if args.save_result: + if args.save_result or args.append_result: result_json: dict[str, Any] = {} # Setup @@ -887,7 +1076,6 @@ def main(args: argparse.Namespace): result_json["label"] = label result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id - result_json["best_of"] = args.best_of result_json["num_prompts"] = args.num_prompts # Metadata @@ -910,6 +1098,19 @@ def main(args: argparse.Namespace): # Merge with benchmark result result_json = {**result_json, **benchmark_result} + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + # Save to file base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" @@ -920,6 +1121,11 @@ def main(args: argparse.Namespace): file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w", encoding='utf-8') as outfile: + with open(file_name, + mode="a+" if args.append_result else "w", + encoding="utf-8") as outfile: + # Append a newline. + if args.append_result and outfile.tell() != 0: + outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name)
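
Example invocations exercising the new options (illustrative only: the model
name is a placeholder, any host/port/endpoint flags are omitted, and the vLLM
OpenAI API server is assumed to be running as described at the top of
vllm/benchmarks/serve.py):

    # Random synthetic prompts through the OpenAI-compatible completions backend.
    vllm bench serve \
        --endpoint-type openai \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name random \
        --random-input-len 1024 \
        --random-output-len 128 \
        --random-range-ratio 0.2 \
        --num-prompts 500

    # A HuggingFace dataset (MT-Bench) through the chat backend, forwarding
    # sampling parameters via the new flags and saving per-request details.
    vllm bench serve \
        --endpoint-type openai-chat \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name hf \
        --dataset-path philschmid/mt-bench \
        --hf-output-len 256 \
        --num-prompts 80 \
        --temperature 0.8 \
        --top-p 0.95 \
        --save-result --save-detailed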