mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:15:42 +08:00
[Misc] Benchmarks for audio models (#16505)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
2ef0dc53b8
commit
9d4ca19d50
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@ -32,6 +33,7 @@ class RequestFuncInput:
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
"scalellm": async_request_openai_completions,
|
||||
"sglang": async_request_openai_completions,
|
||||
|
||||
@ -64,6 +64,7 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASRDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a ASR dataset for transcription.
|
||||
Tested on the following set:
|
||||
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| Dataset | Domain | Speaking Style | hf-subset |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
|
||||
| | | | release3-speaker-adaptation |
|
||||
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
|
||||
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
|
||||
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
|
||||
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||
"<|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
if skipped:
|
||||
logger.warning("%d samples discarded from dataset due to" \
|
||||
" their length being greater than" \
|
||||
" what Whisper supports.", skipped)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -50,7 +50,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
@ -274,10 +274,6 @@ async def benchmark(
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||
["openai-chat", "openai-audio"]):
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||
"'openai-audio' backend.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
|
||||
@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
|
||||
expected_wer,
|
||||
n_examples=-1,
|
||||
max_concurrent_request=None):
|
||||
# TODO refactor to use `ASRDataset`
|
||||
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
|
||||
dataset = load_hf_dataset(dataset_repo)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user