diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 287d500a81de..efd51c79c37c 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import io
 import json
 import os
 import sys
@@ -32,6 +33,7 @@ class RequestFuncInput:
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
     ignore_eos: bool = False
+    language: Optional[str] = None
 
 
 @dataclass
@@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
     return output
 
 
+async def async_request_openai_audio(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    # Lazy import without PlaceholderModule to avoid vllm dep.
+    import soundfile
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        ("transcriptions", "translations")
+    ), ("OpenAI Audio API URL must end with 'transcriptions' "
+        "or 'translations'.")
+
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
+        content = [{"type": "text", "text": request_func_input.prompt}]
+        payload = {
+            "model": request_func_input.model_name \
+                if request_func_input.model_name else request_func_input.model,
+            "temperature": 0.0,
+            "max_completion_tokens": request_func_input.output_len,
+            "stream": True,
+            "language": "en",
+            # Flattened due to multipart/form-data
+            "stream_include_usage": True,
+            "stream_continuous_usage_stats": True
+        }
+        if request_func_input.extra_body:
+            payload.update(request_func_input.extra_body)
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        }
+
+        # Send audio file
+        def to_bytes(y, sr):
+            buffer = io.BytesIO()
+            soundfile.write(buffer, y, sr, format="WAV")
+            buffer.seek(0)
+            return buffer
+
+        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+            form = aiohttp.FormData()
+            form.add_field('file', f, content_type='audio/wav')
+            for key, value in payload.items():
+                form.add_field(key, str(value))
+
+            output = RequestFuncOutput()
+            output.prompt_len = request_func_input.prompt_len
+
+            generated_text = ""
+            ttft = 0.0
+            st = time.perf_counter()
+            most_recent_timestamp = st
+            try:
+                async with session.post(url=api_url,
+                                        data=form,
+                                        headers=headers) as response:
+                    if response.status == 200:
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = chunk_bytes.decode("utf-8").removeprefix(
+                                "data: ")
+                            if chunk != "[DONE]":
+                                timestamp = time.perf_counter()
+                                data = json.loads(chunk)
+
+                                if choices := data.get("choices"):
+                                    content = choices[0]["delta"].get(
+                                        "content")
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(
+                                            timestamp - most_recent_timestamp)
+
+                                    generated_text += content or ""
+                                elif usage := data.get("usage"):
+                                    output.output_tokens = usage.get(
+                                        "completion_tokens")
+
+                                most_recent_timestamp = timestamp
+
+                        output.generated_text = generated_text
+                        output.success = True
+                        output.latency = most_recent_timestamp - st
+                    else:
+                        output.error = response.reason or ""
+                        output.success = False
+            except Exception:
+                output.success = False
+                exc_info = sys.exc_info()
+                output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
         from modelscope import snapshot_download
@@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = {
     "deepspeed-mii": async_request_deepspeed_mii,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,
+    "openai-audio": async_request_openai_audio,
     "tensorrt-llm": async_request_trt_llm,
     "scalellm": async_request_openai_completions,
     "sglang": async_request_openai_completions,
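For reference, the new backend can also be exercised by hand, outside benchmark_serving.py. Below is a minimal sketch, not part of this diff: it assumes a transcription-capable OpenAI-compatible server on localhost:8000, and the endpoint URL, model name, and one-second dummy waveform are illustrative choices only.

# Illustrative sketch: drive async_request_openai_audio directly.
# Assumes benchmarks/ is on PYTHONPATH and a server is listening on
# localhost:8000/v1/audio/transcriptions (both are assumptions).
import asyncio

import numpy as np

from backend_request_func import ASYNC_REQUEST_FUNCS, RequestFuncInput


async def main() -> None:
    sr = 16000
    y = np.zeros(sr, dtype=np.float32)  # one second of silence as dummy audio
    request = RequestFuncInput(
        prompt="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        api_url="http://localhost:8000/v1/audio/transcriptions",
        prompt_len=0,
        output_len=128,
        model="openai/whisper-large-v3",  # illustrative model name
        multi_modal_content={"audio": (y, sr)},
    )
    output = await ASYNC_REQUEST_FUNCS["openai-audio"](request)
    print(output.success, output.ttft, output.generated_text)


if __name__ == "__main__":
    asyncio.run(main())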
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 63f174275d47..ccbc6c022f1f 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -64,6 +64,7 @@ class SampleRequest:
 
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0
+    IS_MULTIMODAL = False
 
     def __init__(
         self,
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
     SUPPORTED_DATASET_PATHS = {
         'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
     }
+    IS_MULTIMODAL = True
 
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         "lmarena-ai/vision-arena-bench-v0.1":
         lambda x: x["turns"][0][0]["content"]
     }
+    IS_MULTIMODAL = True
 
     def sample(
         self,
@@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# ASR Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ASRDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing an ASR dataset for transcription.
+    Tested on the following set:
+
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
+    |                |                                        |                          | release3-speaker-adaptation |
+    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
+    | LibriSpeech    | Audiobook                              | Narrated                 | "clean" and "other"         |
+    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
+    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
+    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
+    +----------------+----------------------------------------+--------------------------+-----------------------------+
+
+    """  # noqa: E501
+    SUPPORTED_DATASET_PATHS = {
+        "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
+        "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
+    }
+
+    DEFAULT_OUTPUT_LEN = 128
+    IS_MULTIMODAL = True
+
+    # TODO Whisper-specific. Abstract interface when more models are supported.
+    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
+                             "<|notimestamps|>"
+    skip_long_audios: bool = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list:
+        import librosa
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
+        prompt_len = len(tokenizer(prompt).input_ids)
+        sampled_requests = []
+        skipped = 0
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            audio = item["audio"]
+            y, sr = audio["array"], audio["sampling_rate"]
+            duration_s = librosa.get_duration(y=y, sr=sr)
+            # Whisper max supported duration
+            if self.skip_long_audios and duration_s > 30:
+                skipped += 1
+                continue
+
+            mm_content = {"audio": (y, sr)}
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                ))
+        if skipped:
+            logger.warning("%d samples discarded from dataset due to" \
+                           " their length being greater than" \
+                           " what Whisper supports.", skipped)
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
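A sketch of how ASRDataset is meant to be consumed (benchmark_serving.py wires this up in the next file). The dataset_split keyword, the split/subset values, and the Whisper tokenizer choice are illustrative assumptions here; dataset_path and dataset_subset mirror the existing HuggingFaceDataset constructor used elsewhere in the benchmark.

# Illustrative sketch: sample transcription requests from an ASR dataset.
# Constructor keywords beyond dataset_path/dataset_subset are assumptions
# based on the existing HuggingFaceDataset interface.
from transformers import AutoTokenizer

from benchmark_dataset import ASRDataset

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
dataset = ASRDataset(
    dataset_path="openslr/librispeech_asr",
    dataset_subset="clean",
    dataset_split="validation",
)
requests = dataset.sample(tokenizer=tokenizer, num_requests=32)
# Each SampleRequest carries the Whisper transcription preamble as its prompt
# and the raw (waveform, sampling_rate) tuple under multi_modal_data["audio"].
print(len(requests), requests[0].expected_output_len)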
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index b5bd840d8410..bc125f6b73dc 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -50,7 +50,7 @@ try:
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
-from benchmark_dataset import (AIMODataset, BurstGPTDataset,
+from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
                                ConversationDataset, HuggingFaceDataset,
                                InstructCoderDataset, RandomDataset,
                                SampleRequest, ShareGPTDataset, SonnetDataset,
@@ -274,10 +274,6 @@ async def benchmark(
         input_requests[0].expected_output_len, \
         input_requests[0].multi_modal_data
 
-    if backend != "openai-chat" and test_mm_content is not None:
-        # multi-modal benchmark is only available on OpenAI Chat backend.
-        raise ValueError(
-            "Multi-modal content is only supported on 'openai-chat' backend.")
     assert test_mm_content is None or isinstance(test_mm_content, dict)
     test_input = RequestFuncInput(
         model=model_id,
@@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
         elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
             dataset_class = AIMODataset
             args.hf_split = "train"
+        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = ASRDataset
+            args.hf_split = "train"
         else:
             supported_datasets = set([
                 dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
                 f" from one of following: {supported_datasets}. "
                 "Please consider contributing if you would "
                 "like to add support for additional dataset formats.")
+
+        if (dataset_class.IS_MULTIMODAL and backend not in
+                ["openai-chat", "openai-audio"]):
+            # Multi-modal benchmarks are only supported by these backends.
+            raise ValueError(
+                "Multi-modal content is only supported on 'openai-chat' and "
+                "'openai-audio' backends.")
     input_requests = dataset_class(
         dataset_path=args.dataset_path,
         dataset_subset=args.hf_subset,
diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
index eca5d184f5d6..642c204b9ff0 100644
--- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
+++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
@@ -150,6 +150,7 @@ def test_wer_correctness(model_name,
                          expected_wer,
                          n_examples=-1,
                          max_concurrent_request=None):
+    # TODO refactor to use `ASRDataset`
     with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
         dataset = load_hf_dataset(dataset_repo)
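The correctness test above scores transcriptions against the dataset references by word error rate (expected_wer). As a reminder of what that metric measures, here is a minimal sketch using the jiwer package; this is an assumption for illustration, and the test itself may compute WER with a different library.

# Illustrative sketch: word error rate between reference and hypothesis text.
# jiwer is an assumed dependency here (pip install jiwer).
import jiwer

references = ["the quick brown fox jumps over the lazy dog"]
hypotheses = ["the quick brown fox jumped over the lazy dog"]

wer = jiwer.wer(references, hypotheses)  # word-level edits / reference length
print(f"WER: {wer:.3f}")  # 1 substitution over 9 words ~= 0.111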