Support datasets in vllm bench serve and sync with benchmark_[serving,datasets].py (#18566)

Michael Goin 2025-05-27 19:59:48 -04:00 committed by GitHub
parent e0cbad4e30
commit e56f44d9ec
3 changed files with 691 additions and 100 deletions

vllm/benchmarks/datasets.py

@@ -62,6 +62,7 @@ class SampleRequest:
class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
IS_MULTIMODAL = False
def __init__(
self,
@@ -316,13 +317,15 @@ class RandomDataset(BenchmarkDataset):
)
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio))
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
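For illustration (not part of the diff), a minimal sketch of the new symmetric sampling range [X * (1 - b), X * (1 + b)] used above; the concrete lengths and range_ratio are made-up example values.

import numpy as np

input_len, output_len, range_ratio = 1024, 128, 0.25   # illustrative values
input_low, input_high = int(input_len * (1 - range_ratio)), int(input_len * (1 + range_ratio))      # 768, 1280
output_low, output_high = int(output_len * (1 - range_ratio)), int(output_len * (1 + range_ratio))  # 96, 160
rng = np.random.default_rng(0)
input_lens = rng.integers(input_low, input_high + 1, size=4)     # per-request input lengths
output_lens = rng.integers(output_low, output_high + 1, size=4)  # per-request output lengths
print(input_lens, output_lens)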
@@ -345,6 +348,17 @@ class RandomDataset(BenchmarkDataset):
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
SampleRequest(
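For context (not part of the diff), a small sketch of the re-tokenization drift described in the comment above, reusing the GPT-2 example from that comment; assumes transformers is installed.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
token_ids = [6880, 6881]                          # ['Ġcalls', 'here'] per the comment above
prompt = tok.decode(token_ids)
re_encoded = tok.encode(prompt, add_special_tokens=False)
print(len(token_ids), len(re_encoded))            # 2 vs. 3: decode + encode changed the length
prompt = tok.decode(re_encoded[:len(token_ids)])  # truncate to keep the prompt length bounded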
@@ -637,6 +651,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL = True
def sample(self,
tokenizer: PreTrainedTokenizerBase,
@@ -701,6 +716,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
}
IS_MULTIMODAL = True
def sample(
self,
@@ -784,6 +800,64 @@ class InstructCoderDataset(HuggingFaceDataset):
return sampled_requests
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------
class MTBenchDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
We create a single turn dataset for MT-Bench.
This is similar to the spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
# apply template
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
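A hedged usage sketch for the new MT-Bench dataset (not part of the diff); the constructor keywords mirror how the serve benchmark builds HuggingFace-backed datasets later in this commit, and the tokenizer is an arbitrary example of a model that ships a chat template.

from transformers import AutoTokenizer
from vllm.benchmarks.datasets import MTBenchDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
requests = MTBenchDataset(
    dataset_path="philschmid/mt-bench",
    dataset_subset=None,
    dataset_split="train",
    random_seed=0,
).sample(tokenizer=tokenizer, num_requests=8)
print(requests[0].prompt_len, requests[0].expected_output_len)  # output_len defaults to 256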
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
@@ -858,18 +932,18 @@ def _format_zeta_prompt(
sample: dict,
original_start_marker: str = "<|editable_region_start|>") -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset
into prompts and expected outputs. It could be
further extended to support more NEP datasets.
Args:
sample: The dataset sample containing events,
inputs, and outputs.
original_start_marker: The marker indicating the
start of the editable region. Defaults to
"<|editable_region_start|>".
Returns:
A dictionary with the formatted prompts and expected outputs.
"""
@@ -919,3 +993,94 @@ class NextEditPredictionDataset(HuggingFaceDataset):
break
self.maybe_oversample_requests(samples, num_requests)
return samples
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------
class ASRDataset(HuggingFaceDataset):
"""
Dataset class for processing an ASR dataset for transcription.
Tested on the following set:
+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
|                |                                        |                          | release3-speaker-adaptation |
| VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
| LibriSpeech    | Audiobook                              | Narrated                 | clean, other, all           |
| GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
| SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
| AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
+----------------+----------------------------------------+--------------------------+-----------------------------+
""" # noqa: E501
SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr",
"facebook/voxpopuli",
"LIUM/tedlium",
"edinburghcstr/ami",
"speechcolab/gigaspeech",
"kensho/spgispeech",
}
DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
skip_long_audios: bool = True
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs,
) -> list:
try:
import librosa
except ImportError as e:
raise ImportError(
"librosa is required for ASRDataset. Please install it "
"using `pip install librosa`.") from e
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
audio = item["audio"]
y, sr = audio["array"], audio["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
# Whisper max supported duration
if self.skip_long_audios and duration_s > 30:
skipped += 1
continue
mm_content = {"audio": (y, sr)}
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
" their length being greater than"
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
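As a standalone illustration of the 30-second Whisper cutoff used above (not part of the diff), assuming a datasets-style item with an "audio" column:

import librosa
import numpy as np

item = {"audio": {"array": np.zeros(16_000 * 45, dtype=np.float32),
                  "sampling_rate": 16_000}}          # 45 s of synthetic audio
y, sr = item["audio"]["array"], item["audio"]["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
keep = duration_s <= 30                              # Whisper's max supported duration
print(duration_s, keep)                              # 45.0 False -> request would be skipped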

vllm/benchmarks/endpoint_request_func.py

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""The request function for API endpoints."""
import io
import json
import os
import sys
@@ -24,11 +25,11 @@ class RequestFuncInput:
output_len: int
model: str
model_name: Optional[str] = None
best_of: int = 1
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
language: Optional[str] = None
@dataclass
@@ -71,7 +72,7 @@ async def async_request_openai_completions(
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
@@ -154,7 +155,226 @@ async def async_request_openai_completions(
return output
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), (
"OpenAI Chat Completions API URL must end with 'chat/completions'.")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model":
request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"messages": [
{
"role": "user",
"content": content
},
],
"temperature":
0.0,
"max_completion_tokens":
request_func_input.output_len,
"stream":
True,
"stream_options": {
"include_usage": True,
},
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
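A simplified, transport-free sketch of the TTFT / inter-token-latency bookkeeping used in the streaming loop above (an illustration, not the benchmark's API):

import time

def record_stream_timings(text_deltas):
    """text_deltas: an iterable of streamed text chunks arriving over time."""
    st = time.perf_counter()
    ttft, itl, generated_text = 0.0, [], ""
    most_recent_timestamp = st
    for delta in text_deltas:
        timestamp = time.perf_counter()
        if ttft == 0.0:                  # first chunk -> time to first token
            ttft = timestamp - st
        else:                            # decoding phase -> inter-token latency
            itl.append(timestamp - most_recent_timestamp)
        most_recent_timestamp = timestamp
        generated_text += delta or ""
    latency = most_recent_timestamp - st
    return ttft, itl, latency, generated_text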
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(("transcriptions", "translations")), (
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
"or `translations`."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model":
request_func_input.model_name
if request_func_input.model_name else request_func_input.model,
"temperature":
0.0,
"max_completion_tokens":
request_func_input.output_len,
"stream":
True,
"language":
"en",
# Flattened due to multipart/form-data
"stream_include_usage":
True,
"stream_continuous_usage_stats":
True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
"openai-comp": async_request_openai_completions,
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
}
OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions,
async_request_openai_chat_completions)
]
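A hypothetical end-to-end sketch (not part of the diff) of how the benchmark drives these request functions; the URL and model name are placeholders, and a compatible server must already be running before the commented-out call is executed.

import asyncio

async def _demo():
    request_func = ASYNC_REQUEST_FUNCS["openai-chat"]
    req = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://localhost:8000/v1/chat/completions",  # placeholder URL
        prompt_len=4,
        output_len=16,
        model="my-model",                                     # placeholder model name
        ignore_eos=True,
        extra_body={"temperature": 0.0},
    )
    out = await request_func(request_func_input=req)
    print(out.success, out.ttft, out.output_tokens)

# asyncio.run(_demo())  # uncomment once a server is up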

vllm/benchmarks/serve.py

@@ -7,7 +7,7 @@ to launch the vLLM OpenAI API server:
On the client side, run:
vllm bench serve \
--endpoint-type <endpoint_type. Default 'openi-comp'> \
--endpoint-type <endpoint_type. Default 'openai'> \
--label <benchmark result label. Default using endpoint_type> \
--model <your_model> \
--dataset-name <dataset_name. Default 'random'> \
@@ -22,7 +22,7 @@ import os
import random
import time
import warnings
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
@@ -31,7 +31,14 @@ import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, MTBenchDataset,
NextEditPredictionDataset, RandomDataset,
SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput)
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
@@ -71,53 +78,18 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]]
def sample_random_requests(
prefix_len: int,
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
) -> list[tuple[str, int, int]]:
prefix_token_ids = np.random.randint(0,
tokenizer.vocab_size,
size=prefix_len).tolist()
input_lens = np.random.randint(
int(input_len * range_ratio),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode(prefix_token_ids +
[(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])
input_requests.append((prompt, int(prefix_len + input_lens[i]),
int(output_lens[i]), None))
return input_requests
async def get_request(
input_requests: list[tuple[str, int, int]],
input_requests: list[SampleRequest],
request_rate: float,
burstiness: float = 1.0,
) -> AsyncGenerator[tuple[str, int, int], None]:
) -> AsyncGenerator[SampleRequest, None]:
"""
Asynchronously generates requests at a specified rate
with OPTIONAL burstiness.
Args:
input_requests:
A list of input requests, each represented as a tuple.
A list of input requests, each represented as a SampleRequest.
request_rate:
The rate at which requests are generated (requests/s).
burstiness (optional):
@@ -129,7 +101,7 @@ async def get_request(
in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests.
"""
input_requests = iter(input_requests)
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
@@ -151,7 +123,7 @@
def calculate_metrics(
input_requests: list[tuple[str, int, int]],
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
@@ -184,7 +156,7 @@ def calculate_metrics(
if outputs[i].success:
output_len = outputs[i].output_tokens
if output_len is None:
if not output_len:
# We use the tokenizer to count the number of output tokens
# for some serving backends instead of looking at
# len(outputs[i].itl) since multiple output tokens may be
@@ -194,7 +166,7 @@ def calculate_metrics(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
total_input += input_requests[i][1]
total_input += input_requests[i].prompt_len
tpot = 0
if output_len > 1:
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
@@ -277,19 +249,19 @@ async def benchmark(
model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: list[tuple[str, int, int]],
input_requests: list[SampleRequest],
logprobs: Optional[int],
best_of: int,
request_rate: float,
burstiness: float,
disable_tqdm: bool,
profile: bool,
selected_percentile_metrics: list[str],
selected_percentiles: list[str],
selected_percentiles: list[float],
ignore_eos: bool,
goodput_config_dict: dict[str, float],
max_concurrency: Optional[int],
lora_modules: Optional[list[str]],
lora_modules: Optional[Iterable[str]],
extra_body: Optional[dict],
):
if endpoint_type in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -298,11 +270,13 @@ async def benchmark(
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0])
if endpoint_type != "openai-chat" and test_mm_content is not None:
# multi-modal benchmark is only available on OpenAI Chat endpoint.
raise ValueError("Multi-modal content is only supported on "
"'openai-chat' endpoint_type.")
input_requests[0].prompt,
input_requests[0].prompt_len,
input_requests[0].expected_output_len,
input_requests[0].multi_modal_data,
)
assert test_mm_content is None or isinstance(test_mm_content, dict)
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
@@ -311,9 +285,9 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body,
)
test_output = await request_func(request_func_input=test_input)
@@ -338,9 +312,9 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
extra_body=extra_body)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
@@ -374,7 +348,12 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request
prompt, prompt_len, output_len, mm_content = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
@@ -387,9 +366,9 @@ async def benchmark(
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=mm_content,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
extra_body=extra_body)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
@@ -405,7 +384,6 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -567,7 +545,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--endpoint-type",
type=str,
default="openai-comp",
default="openai",
choices=list(ASYNC_REQUEST_FUNCS.keys()),
)
parser.add_argument(
@@ -596,9 +574,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--dataset-name",
type=str,
default="random",
choices=["random"],
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--max-concurrency",
type=int,
@@ -624,13 +609,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-prompts",
@@ -691,6 +669,17 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Specify to save benchmark results to a json file",
)
parser.add_argument(
"--save-detailed",
action="store_true",
help="When saving the results, whether to include per request "
"information such as response, error, ttfs, tpots, etc.",
)
parser.add_argument(
"--append-result",
action="store_true",
help="Append the benchmark result to the existing json file.",
)
parser.add_argument(
"--metadata",
metavar="KEY=VALUE",
@@ -733,6 +722,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\"."
"Use \"--percentile-metrics\" to select metrics.",
)
parser.add_argument(
@@ -745,7 +735,41 @@ def add_cli_args(parser: argparse.ArgumentParser):
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
sharegpt_group.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
@@ -765,9 +789,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
random_group.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
default=0.0,
help="Range ratio for sampling input/output length, "
"used only for random sampling. Must be in the range [0, 1) to define "
"a symmetric sampling range"
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
random_group.add_argument(
"--random-prefix-len",
@@ -778,6 +804,54 @@ def add_cli_args(parser: argparse.ArgumentParser):
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument(
"--top-p",
type=float,
default=None,
help="Top-p sampling parameter. Only has effect on "
"openai-compatible backends.",
)
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
help="Top-k sampling parameter. Only has effect on "
"openai-compatible backends.",
)
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
help="Min-p sampling parameter. Only has effect on "
"openai-compatible backends.",
)
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).",
)
parser.add_argument(
'--tokenizer-mode',
type=str,
@@ -826,27 +900,142 @@ def main(args: argparse.Namespace):
tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code)
# TODO: This should be refactored to use the benchmark_dataset.py
# in later PRs.
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
if args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
input_requests = dataset.sample(
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
input_requests = dataset.sample(
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
)
elif args.dataset_name == "hf":
# all following datasets are implemented from the
# HuggingFaceDataset base class
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_class = VisionArenaDataset
args.hf_split = "train"
args.hf_subset = None
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_class = InstructCoderDataset
args.hf_split = "train"
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
dataset_class = MTBenchDataset
args.hf_split = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ConversationDataset
args.hf_split = "train"
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_class = AIMODataset
args.hf_split = "train"
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
dataset_class = NextEditPredictionDataset
args.hf_split = "train"
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ASRDataset
args.hf_split = "train"
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
for dataset_name in cls.SUPPORTED_DATASET_PATHS
])
raise ValueError(
f"Unsupported dataset path: {args.dataset_path}. "
"Huggingface dataset only supports dataset_path"
f" from one of following: {supported_datasets}. "
"Please consider contributing if you would "
"like to add support for additional dataset formats.")
if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [
"openai-chat",
"openai-audio",
]:
# Multi-modal benchmarks are only available on the OpenAI Chat and Audio endpoints.
raise ValueError(
"Multi-modal content is only supported on the 'openai-chat' and "
"'openai-audio' endpoint types.")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
dataset_split=args.hf_split,
random_seed=args.seed,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.hf_output_len,
)
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random":
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
),
}
try:
input_requests = dataset_mapping[args.dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
goodput_config_dict = check_goodput_args(args)
# Collect the sampling parameters.
sampling_params = {
k: v
for k, v in {
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature,
}.items() if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.endpoint_type not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError("Sampling parameters are only supported by "
"openai-compatible backends.")
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()
@@ -861,7 +1050,6 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
best_of=args.best_of,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
@@ -874,10 +1062,11 @@ def main(args: argparse.Namespace):
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
# Save config and results to json
if args.save_result:
if args.save_result or args.append_result:
result_json: dict[str, Any] = {}
# Setup
@@ -887,7 +1076,6 @@ def main(args: argparse.Namespace):
result_json["label"] = label
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
result_json["num_prompts"] = args.num_prompts
# Metadata
@@ -910,6 +1098,19 @@ def main(args: argparse.Namespace):
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
if not args.save_detailed:
# Remove fields with too many data points
for field in [
"input_lens",
"output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]:
if field in result_json:
del result_json[field]
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
@@ -920,6 +1121,11 @@ def main(args: argparse.Namespace):
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile:
with open(file_name,
mode="a+" if args.append_result else "w",
encoding="utf-8") as outfile:
# Separate appended results with a newline.
if args.append_result and outfile.tell() != 0:
outfile.write("\n")
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
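Since --append-result separates runs with newlines, the result file can be read back as JSON Lines; a minimal sketch (the file name is a placeholder for whatever --result-dir / --result-filename resolve to):

import json

results = []
with open("benchmark-results.json", encoding="utf-8") as f:   # placeholder path
    for line in f:
        if line.strip():
            results.append(json.loads(line))
print(len(results), "benchmark runs loaded")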