[Misc] Benchmarks for audio models (#16505)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 2ef0dc53b8
commit 9d4ca19d50
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import io
 import json
 import os
 import sys
@@ -32,6 +33,7 @@ class RequestFuncInput:
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
     ignore_eos: bool = False
+    language: Optional[str] = None
 
 
 @dataclass
@@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
     return output
 
 
+async def async_request_openai_audio(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    # Lazy import without PlaceholderModule to avoid vllm dep.
+    import soundfile
+
+    api_url = request_func_input.api_url
+    assert api_url.endswith(("transcriptions", "translations")), \
+        "OpenAI audio API URL must end with 'transcriptions' or 'translations'."
+
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
+        content = [{"type": "text", "text": request_func_input.prompt}]
+        payload = {
+            "model": request_func_input.model_name
+            if request_func_input.model_name else request_func_input.model,
+            "temperature": 0.0,
+            "max_completion_tokens": request_func_input.output_len,
+            "stream": True,
+            "language": "en",
+            # Flattened due to multipart/form-data
+            "stream_include_usage": True,
+            "stream_continuous_usage_stats": True
+        }
+        if request_func_input.extra_body:
+            payload.update(request_func_input.extra_body)
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        }
+
+        # Send audio file
+        def to_bytes(y, sr):
+            buffer = io.BytesIO()
+            soundfile.write(buffer, y, sr, format="WAV")
+            buffer.seek(0)
+            return buffer
+
+        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+            form = aiohttp.FormData()
+            form.add_field('file', f, content_type='audio/wav')
+            for key, value in payload.items():
+                form.add_field(key, str(value))
+
+            output = RequestFuncOutput()
+            output.prompt_len = request_func_input.prompt_len
+
+            generated_text = ""
+            ttft = 0.0
+            st = time.perf_counter()
+            most_recent_timestamp = st
+            try:
+                async with session.post(url=api_url,
+                                        data=form,
+                                        headers=headers) as response:
+                    if response.status == 200:
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data: ")
+                            if chunk != "[DONE]":
+                                timestamp = time.perf_counter()
+                                data = json.loads(chunk)
+
+                                if choices := data.get("choices"):
+                                    content = choices[0]["delta"].get(
+                                        "content")
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(
+                                            timestamp - most_recent_timestamp)
+
+                                    generated_text += content or ""
+                                elif usage := data.get("usage"):
+                                    output.output_tokens = usage.get(
+                                        "completion_tokens")
+
+                                most_recent_timestamp = timestamp
+
+                        output.generated_text = generated_text
+                        output.success = True
+                        output.latency = most_recent_timestamp - st
+                    else:
+                        output.error = response.reason or ""
+                        output.success = False
+            except Exception:
+                output.success = False
+                exc_info = sys.exc_info()
+                output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
         from modelscope import snapshot_download
@@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = {
     "deepspeed-mii": async_request_deepspeed_mii,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,
+    "openai-audio": async_request_openai_audio,
     "tensorrt-llm": async_request_trt_llm,
     "scalellm": async_request_openai_completions,
     "sglang": async_request_openai_completions,
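Not part of the diff: a minimal sketch of driving the new "openai-audio" backend directly. The server URL, Whisper model name, and synthetic audio below are assumptions; the (y, sr) audio tuple mirrors the format that ASRDataset (added further down) places into multi_modal_content.

import asyncio

import numpy as np

from backend_request_func import (RequestFuncInput,
                                  async_request_openai_audio)


async def demo():
    y = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    request = RequestFuncInput(
        model="openai/whisper-large-v3",  # assumed model being served
        prompt="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        api_url="http://localhost:8000/v1/audio/transcriptions",
        prompt_len=0,
        output_len=128,
        multi_modal_content={"audio": (y, 16000)},
    )
    output = await async_request_openai_audio(request)
    print(output.success, output.ttft, output.generated_text)


asyncio.run(demo())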
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -64,6 +64,7 @@ class SampleRequest:
 
 
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
+    IS_MULTIMODAL = False
 
     def __init__(
         self,
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
     SUPPORTED_DATASET_PATHS = {
         'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
     }
+    IS_MULTIMODAL = True
 
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         "lmarena-ai/vision-arena-bench-v0.1":
         lambda x: x["turns"][0][0]["content"]
     }
+    IS_MULTIMODAL = True
 
     def sample(
         self,
@@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# ASR Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ASRDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing an ASR dataset for transcription.
+    Tested on the following set:
+
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
+    |                |                                        |                          | release3-speaker-adaptation |
+    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
+    | LibriSpeech    | Audiobook                              | Narrated                 | clean, other                |
+    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
+    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
+    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+
+    """  # noqa: E501
+    SUPPORTED_DATASET_PATHS = {
+        "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
+        "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
+    }
+
+    DEFAULT_OUTPUT_LEN = 128
+    IS_MULTIMODAL = True
+
+    # TODO Whisper-specific. Abstract interface when more models are supported.
+    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>" \
+                             "<|notimestamps|>"
+    skip_long_audios: bool = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list:
+        # Lazy import; librosa is only needed for audio datasets.
+        import librosa
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
+        prompt_len = len(tokenizer(prompt).input_ids)
+        sampled_requests = []
+        skipped = 0
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            audio = item["audio"]
+            y, sr = audio["array"], audio["sampling_rate"]
+            duration_s = librosa.get_duration(y=y, sr=sr)
+            # Whisper max supported duration
+            if self.skip_long_audios and duration_s > 30:
+                skipped += 1
+                continue
+
+            mm_content = {"audio": (y, sr)}
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                ))
+        if skipped:
+            logger.warning(
+                "%d samples discarded from dataset due to their length "
+                "being greater than what Whisper supports.", skipped)
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
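A usage sketch for the new dataset class, assuming it is run from the benchmarks/ directory; the constructor keywords follow the dataset_class(...) call in benchmark_serving.py below, and the subset/split choices are illustrative.

from benchmark_dataset import ASRDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
dataset = ASRDataset(dataset_path="openslr/librispeech_asr",
                     dataset_subset="clean",  # illustrative subset
                     dataset_split="validation")
requests = dataset.sample(tokenizer=tokenizer, num_requests=4)
for req in requests:
    y, sr = req.multi_modal_data["audio"]
    print(f"{len(y) / sr:.1f}s audio, prompt_len={req.prompt_len}, "
          f"expected_output_len={req.expected_output_len}")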
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -50,7 +50,7 @@ try:
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
-from benchmark_dataset import (AIMODataset, BurstGPTDataset,
+from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
                                ConversationDataset, HuggingFaceDataset,
                                InstructCoderDataset, RandomDataset,
                                SampleRequest, ShareGPTDataset, SonnetDataset,
@@ -274,10 +274,6 @@ async def benchmark(
         input_requests[0].expected_output_len, \
         input_requests[0].multi_modal_data
 
-    if backend != "openai-chat" and test_mm_content is not None:
-        # multi-modal benchmark is only available on OpenAI Chat backend.
-        raise ValueError(
-            "Multi-modal content is only supported on 'openai-chat' backend.")
     assert test_mm_content is None or isinstance(test_mm_content, dict)
     test_input = RequestFuncInput(
         model=model_id,
@@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
     elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
         dataset_class = AIMODataset
         args.hf_split = "train"
+    elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
+        dataset_class = ASRDataset
+        args.hf_split = "train"
     else:
         supported_datasets = set([
             dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
             f" from one of following: {supported_datasets}. "
             "Please consider contributing if you would "
             "like to add support for additional dataset formats.")
 
+    if (dataset_class.IS_MULTIMODAL
+            and backend not in ["openai-chat", "openai-audio"]):
+        # Multi-modal benchmarks are only available on the OpenAI Chat and
+        # OpenAI Audio backends.
+        raise ValueError(
+            "Multi-modal content is only supported on 'openai-chat' and "
+            "'openai-audio' backends.")
     input_requests = dataset_class(
         dataset_path=args.dataset_path,
         dataset_subset=args.hf_subset,
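With the pieces above wired together, an end-to-end audio benchmark run against a server hosting a Whisper-style model would look roughly like the following; the flags mirror benchmark_serving.py's existing CLI, and the model/dataset choices are illustrative.

python benchmarks/benchmark_serving.py \
    --backend openai-audio \
    --endpoint /v1/audio/transcriptions \
    --model openai/whisper-large-v3 \
    --dataset-name hf \
    --dataset-path openslr/librispeech_asr \
    --hf-subset clean \
    --num-prompts 100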
--- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
+++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
@@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
                          expected_wer,
                          n_examples=-1,
                          max_concurrent_request=None):
+    # TODO refactor to use `ASRDataset`
     with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
         dataset = load_hf_dataset(dataset_repo)