mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:45:01 +08:00
Support datasets in vllm bench serve and sync with benchmark_[serving,datasets].py (#18566)
This commit is contained in:
parent
e0cbad4e30
commit
e56f44d9ec
@ -62,6 +62,7 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -316,13 +317,15 @@ class RandomDataset(BenchmarkDataset):
|
||||
)
|
||||
|
||||
vocab_size = tokenizer.vocab_size
|
||||
num_special_tokens = tokenizer.num_special_tokens_to_add()
|
||||
real_input_len = input_len - num_special_tokens
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
|
||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
||||
input_low = int(input_len * (1 - range_ratio))
|
||||
input_high = int(input_len * (1 + range_ratio))
|
||||
input_low = int(real_input_len * (1 - range_ratio))
|
||||
input_high = int(real_input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
@ -345,6 +348,17 @@ class RandomDataset(BenchmarkDataset):
|
||||
vocab_size).tolist()
|
||||
token_sequence = prefix_token_ids + inner_seq
|
||||
prompt = tokenizer.decode(token_sequence)
|
||||
# After decoding the prompt we have to encode and decode it again.
|
||||
# This is done because in some cases N consecutive tokens
|
||||
# give a string tokenized into != N number of tokens.
|
||||
# For example for GPT2Tokenizer:
|
||||
# [6880, 6881] -> ['Ġcalls', 'here'] ->
|
||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||
# To avoid uncontrolled change of the prompt length,
|
||||
# the encoded sequence is truncated before being decode again.
|
||||
re_encoded_sequence = tokenizer.encode(
|
||||
prompt, add_special_tokens=False)[:input_lens[i]]
|
||||
prompt = tokenizer.decode(re_encoded_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
@ -637,6 +651,7 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -701,6 +716,7 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -784,6 +800,64 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MT-Bench Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MTBenchDataset(HuggingFaceDataset):
|
||||
"""
|
||||
MT-Bench Dataset.
|
||||
https://huggingface.co/datasets/philschmid/mt-bench
|
||||
|
||||
We create a single turn dataset for MT-Bench.
|
||||
This is similar to Spec decoding benchmark setup in vLLM
|
||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"philschmid/mt-bench",
|
||||
}
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["turns"][0]
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AIMO Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -858,18 +932,18 @@ def _format_zeta_prompt(
|
||||
sample: dict,
|
||||
original_start_marker: str = "<|editable_region_start|>") -> dict:
|
||||
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
further extended to support more NEP datasets.
|
||||
|
||||
|
||||
Args:
|
||||
sample: The dataset sample containing events,
|
||||
sample: The dataset sample containing events,
|
||||
inputs, and outputs.
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
"<|editable_region_start|>".
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary with the formatted prompts and expected outputs.
|
||||
"""
|
||||
@ -919,3 +993,94 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
break
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASRDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a ASR dataset for transcription.
|
||||
Tested on the following set:
|
||||
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| Dataset | Domain | Speaking Style | hf-subset |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
|
||||
| | | | release3-speaker-adaptation |
|
||||
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
|
||||
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
|
||||
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
|
||||
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr",
|
||||
"facebook/voxpopuli",
|
||||
"LIUM/tedlium",
|
||||
"edinburghcstr/ami",
|
||||
"speechcolab/gigaspeech",
|
||||
"kensho/spgispeech",
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>")
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
try:
|
||||
import librosa
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"librosa is required for ASRDataset. Please install it "
|
||||
"using `pip install librosa`.") from e
|
||||
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
if skipped:
|
||||
logger.warning(
|
||||
"%d samples discarded from dataset due to"
|
||||
" their length being greater than"
|
||||
" what Whisper supports.",
|
||||
skipped,
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""The request function for API endpoints."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@ -24,11 +25,11 @@ class RequestFuncInput:
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
best_of: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -71,7 +72,7 @@ async def async_request_openai_completions(
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
@ -154,7 +155,226 @@ async def async_request_openai_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("chat/completions", "profile")), (
|
||||
"OpenAI Chat Completions API URL must end with 'chat/completions'.")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model":
|
||||
request_func_input.model_name
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
},
|
||||
],
|
||||
"temperature":
|
||||
0.0,
|
||||
"max_completion_tokens":
|
||||
request_func_input.output_len,
|
||||
"stream":
|
||||
True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("transcriptions", "translations")), (
|
||||
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model":
|
||||
request_func_input.model_name
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature":
|
||||
0.0,
|
||||
"max_completion_tokens":
|
||||
request_func_input.output_len,
|
||||
"stream":
|
||||
True,
|
||||
"language":
|
||||
"en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage":
|
||||
True,
|
||||
"stream_continuous_usage_stats":
|
||||
True,
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
# TODO: Add more request functions for different API protocols.
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"openai-comp": async_request_openai_completions,
|
||||
"vllm": async_request_openai_completions,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_openai_chat_completions)
|
||||
]
|
||||
|
||||
@ -7,7 +7,7 @@ to launch the vLLM OpenAI API server:
|
||||
|
||||
On the client side, run:
|
||||
vllm bench serve \
|
||||
--endpoint-type <endpoint_type. Default 'openi-comp'> \
|
||||
--endpoint-type <endpoint_type. Default 'openai'> \
|
||||
--label <benchmark result label. Default using endpoint_type> \
|
||||
--model <your_model> \
|
||||
--dataset-name <dataset_name. Default 'random'> \
|
||||
@ -22,7 +22,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
@ -31,7 +31,14 @@ import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, MTBenchDataset,
|
||||
NextEditPredictionDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS,
|
||||
OPENAI_COMPATIBLE_BACKENDS,
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
|
||||
@ -71,53 +78,18 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def sample_random_requests(
|
||||
prefix_len: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, int, int]]:
|
||||
prefix_token_ids = np.random.randint(0,
|
||||
tokenizer.vocab_size,
|
||||
size=prefix_len).tolist()
|
||||
|
||||
input_lens = np.random.randint(
|
||||
int(input_len * range_ratio),
|
||||
input_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
output_lens = np.random.randint(
|
||||
int(output_len * range_ratio),
|
||||
output_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||
input_requests = []
|
||||
for i in range(num_prompts):
|
||||
prompt = tokenizer.decode(prefix_token_ids +
|
||||
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])])
|
||||
|
||||
input_requests.append((prompt, int(prefix_len + input_lens[i]),
|
||||
int(output_lens[i]), None))
|
||||
|
||||
return input_requests
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[tuple[str, int, int], None]:
|
||||
) -> AsyncGenerator[SampleRequest, None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a tuple.
|
||||
A list of input requests, each represented as a SampleRequest.
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
@ -129,7 +101,7 @@ async def get_request(
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
"""
|
||||
input_requests = iter(input_requests)
|
||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
@ -151,7 +123,7 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -184,7 +156,7 @@ def calculate_metrics(
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
if not output_len:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
@ -194,7 +166,7 @@ def calculate_metrics(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
@ -277,19 +249,19 @@ async def benchmark(
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
logprobs: Optional[int],
|
||||
best_of: int,
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[str],
|
||||
selected_percentiles: list[float],
|
||||
ignore_eos: bool,
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[list[str]],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
):
|
||||
if endpoint_type in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
|
||||
@ -298,11 +270,13 @@ async def benchmark(
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
input_requests[0])
|
||||
if endpoint_type != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat endpoint.
|
||||
raise ValueError("Multi-modal content is only supported on "
|
||||
"'openai-chat' endpoint_type.")
|
||||
input_requests[0].prompt,
|
||||
input_requests[0].prompt_len,
|
||||
input_requests[0].expected_output_len,
|
||||
input_requests[0].multi_modal_data,
|
||||
)
|
||||
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
@ -311,9 +285,9 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
@ -338,9 +312,9 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
@ -374,7 +348,12 @@ async def benchmark(
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
prompt, prompt_len, output_len, mm_content = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
request.expected_output_len,
|
||||
request.multi_modal_data,
|
||||
)
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
@ -387,9 +366,9 @@ async def benchmark(
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
@ -405,7 +384,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -567,7 +545,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--endpoint-type",
|
||||
type=str,
|
||||
default="openai-comp",
|
||||
default="openai",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -596,9 +574,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="random",
|
||||
choices=["random"],
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
@ -624,13 +609,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
@ -691,6 +669,17 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-detailed",
|
||||
action="store_true",
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-result",
|
||||
action="store_true",
|
||||
help="Append the benchmark result to the existing json file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -733,6 +722,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
default="99",
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\"."
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -745,7 +735,41 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
"separated by spaces. Allowed request level metric names are "
|
||||
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
|
||||
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
||||
sharegpt_group.add_argument(
|
||||
"--sharegpt-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output length "
|
||||
"from the ShareGPT dataset.",
|
||||
)
|
||||
|
||||
random_group = parser.add_argument_group("random dataset options")
|
||||
random_group.add_argument(
|
||||
@ -765,9 +789,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
default=0.0,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range"
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
@ -778,6 +804,54 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
hf_group.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
hf_group.add_argument(
|
||||
"--hf-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output lengths "
|
||||
"from the sampled HF dataset.",
|
||||
)
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling parameters")
|
||||
sampling_group.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Top-p sampling parameter. Only has effect on "
|
||||
"openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter. Only has effect on "
|
||||
"openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--min-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Min-p sampling parameter. Only has effect on "
|
||||
"openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature sampling parameter. Only has effect on "
|
||||
"openai-compatible backends. If not specified, default to greedy "
|
||||
"decoding (i.e. temperature==0.0).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
@ -826,27 +900,142 @@ def main(args: argparse.Namespace):
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
# TODO: This should be refactored to use the benchmark_dataset.py
|
||||
# in later PRs.
|
||||
|
||||
if args.dataset_name is None:
|
||||
raise ValueError(
|
||||
"Please specify '--dataset-name' and the corresponding "
|
||||
"'--dataset-path' if required.")
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
num_prompts=args.num_prompts,
|
||||
range_ratio=args.random_range_ratio,
|
||||
|
||||
if args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False,
|
||||
)
|
||||
else:
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
# all following datasets are implemented from the
|
||||
# HuggingFaceDataset base class
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = VisionArenaDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_subset = None
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = InstructCoderDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = MTBenchDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ConversationDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
|
||||
dataset_class = NextEditPredictionDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||
])
|
||||
raise ValueError(
|
||||
f"Unsupported dataset path: {args.dataset_path}. "
|
||||
"Huggingface dataset only supports dataset_path"
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if dataset_class.IS_MULTIMODAL and endpoint_type not in [
|
||||
"openai-chat",
|
||||
"openai-audio",
|
||||
]:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and "
|
||||
"'openai-audio' backend.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
# For datasets that follow a similar structure, use a mapping.
|
||||
dataset_mapping = {
|
||||
"sharegpt":
|
||||
lambda: ShareGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
),
|
||||
"burstgpt":
|
||||
lambda: BurstGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).
|
||||
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||
"random":
|
||||
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
range_ratio=args.random_range_ratio,
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
input_requests = dataset_mapping[args.dataset_name]()
|
||||
except KeyError as err:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Collect the sampling parameters.
|
||||
sampling_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"min_p": args.min_p,
|
||||
"temperature": args.temperature,
|
||||
}.items() if v is not None
|
||||
}
|
||||
|
||||
# Sampling parameters are only supported by openai-compatible backend.
|
||||
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
||||
raise ValueError("Sampling parameters are only supported by "
|
||||
"openai-compatible backends.")
|
||||
|
||||
if "temperature" not in sampling_params:
|
||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
@ -861,7 +1050,6 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
best_of=args.best_of,
|
||||
request_rate=args.request_rate,
|
||||
burstiness=args.burstiness,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
@ -874,10 +1062,11 @@ def main(args: argparse.Namespace):
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
if args.save_result or args.append_result:
|
||||
result_json: dict[str, Any] = {}
|
||||
|
||||
# Setup
|
||||
@ -887,7 +1076,6 @@ def main(args: argparse.Namespace):
|
||||
result_json["label"] = label
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Metadata
|
||||
@ -910,6 +1098,19 @@ def main(args: argparse.Namespace):
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
"input_lens",
|
||||
"output_lens",
|
||||
"ttfts",
|
||||
"itls",
|
||||
"generated_texts",
|
||||
"errors",
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||
@ -920,6 +1121,11 @@ def main(args: argparse.Namespace):
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name, "w", encoding='utf-8') as outfile:
|
||||
with open(file_name,
|
||||
mode="a+" if args.append_result else "w",
|
||||
encoding="utf-8") as outfile:
|
||||
# Append a newline.
|
||||
if args.append_result and outfile.tell() != 0:
|
||||
outfile.write("\n")
|
||||
json.dump(result_json, outfile)
|
||||
save_to_pytorch_benchmark_format(args, result_json, file_name)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user