[Frontend] add chunking audio for > 30s audio (#19597)

Signed-off-by: nguyenhoangthuan99 <thuanhppro12@gmail.com>
This commit is contained in:
nguyenhoangthuan99 2025-06-17 10:34:00 +07:00 committed by GitHub
parent 07334959d8
commit ede5c4ebdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 161 additions and 88 deletions

View File

@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
language="hh", language="hh",
temperature=0.0) temperature=0.0)
# Expect audio too long: repeat the timeseries
mary_had_lamb.seek(0) @pytest.mark.asyncio
audio, sr = librosa.load(mary_had_lamb) async def test_long_audio_request(mary_had_lamb):
repeated_audio = np.tile(audio, 10) model_name = "openai/whisper-large-v3-turbo"
# Repeated audio to buffer server_args = ["--enforce-eager"]
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV') mary_had_lamb.seek(0)
buffer.seek(0) audio, sr = librosa.load(mary_had_lamb)
with pytest.raises(openai.BadRequestError): repeated_audio = np.tile(audio, 10)
await client.audio.transcriptions.create(model=model_name, # Repeated audio to buffer
file=buffer, buffer = io.BytesIO()
language="en", sf.write(buffer, repeated_audio, sr, format='WAV')
temperature=0.0) buffer.seek(0)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=buffer,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import io import io
import math
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from math import ceil from math import ceil
from typing import Final, Optional, Union, cast from typing import Final, Optional, Union, cast
import numpy as np
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -143,6 +145,8 @@ ISO639_1_OTHER_LANGS = {
# As per https://platform.openai.com/docs/guides/speech-to-text#overview. # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable # TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25 MAX_AUDIO_CLIP_FILESIZE_MB = 25
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
class OpenAIServingTranscription(OpenAIServing): class OpenAIServingTranscription(OpenAIServing):
@ -178,7 +182,7 @@ class OpenAIServingTranscription(OpenAIServing):
self, self,
request: TranscriptionRequest, request: TranscriptionRequest,
audio_data: bytes, audio_data: bytes,
) -> tuple[PromptType, float]: ) -> tuple[list[PromptType], float]:
# Validate request # Validate request
# TODO language should be optional and can be guessed. # TODO language should be optional and can be guessed.
# For now we default to en. See # For now we default to en. See
@ -206,22 +210,22 @@ class OpenAIServingTranscription(OpenAIServing):
y, sr = librosa.load(bytes_) y, sr = librosa.load(bytes_)
duration = librosa.get_duration(y=y, sr=sr) duration = librosa.get_duration(y=y, sr=sr)
if duration > self.max_audio_clip_s: chunks = [y] if duration < 30 else self._split_audio(y, sr)
raise ValueError( prompts = []
f"Maximum clip duration ({self.max_audio_clip_s}s) " for i, chunk in enumerate(chunks):
"exceeded.") prompt = {
"encoder_prompt": {
prompt = { "prompt": "",
"encoder_prompt": { "multi_modal_data": {
"prompt": "", "audio": (chunk, sr),
"multi_modal_data": { },
"audio": (y, sr),
}, },
}, "decoder_prompt":
"decoder_prompt": f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" if i == 0 else ""
} }
return cast(PromptType, prompt), duration prompts.append(cast(PromptType, prompt))
return prompts, duration
# TODO (varun) : Make verbose response work ! # TODO (varun) : Make verbose response work !
async def create_transcription( async def create_transcription(
@ -268,7 +272,7 @@ class OpenAIServingTranscription(OpenAIServing):
"Currently do not support PromptAdapter for Transcription." "Currently do not support PromptAdapter for Transcription."
) )
prompt, duration_s = await self._preprocess_transcription( prompts, duration_s = await self._preprocess_transcription(
request=request, request=request,
audio_data=audio_data, audio_data=audio_data,
) )
@ -277,7 +281,8 @@ class OpenAIServingTranscription(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
None]]] = None
try: try:
# Unlike most decoder-only models, whisper generation length is not # Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a # constrained by the size of the input audio, which is mapped to a
@ -288,32 +293,36 @@ class OpenAIServingTranscription(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
prompt['decoder_prompt'], # type: ignore prompts[0]['decoder_prompt'], # type: ignore
params=sampling_params, params=sampling_params,
lora_request=None, lora_request=None,
prompt_adapter_request=None) prompt_adapter_request=None)
result_generator = self.engine_client.generate( list_result_generator = [
prompt, self.engine_client.generate(
sampling_params, prompt,
request_id, sampling_params,
) request_id,
) for prompt in prompts
]
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
if request.stream: if request.stream:
return self.transcription_stream_generator(request, return self.transcription_stream_generator(request,
result_generator, list_result_generator,
request_id, request_id,
request_metadata, request_metadata,
duration_s) duration_s)
# Non-streaming response. # Non-streaming response.
try: try:
assert result_generator is not None assert list_result_generator is not None
async for op in result_generator: text = ""
result = op for result_generator in list_result_generator:
return TranscriptionResponse(text=result.outputs[0].text) async for op in result_generator:
text += op.outputs[0].text
return TranscriptionResponse(text=text)
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
@ -322,7 +331,7 @@ class OpenAIServingTranscription(OpenAIServing):
async def transcription_stream_generator( async def transcription_stream_generator(
self, request: TranscriptionRequest, self, request: TranscriptionRequest,
result_generator: AsyncGenerator[RequestOutput, None], list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str, request_metadata: RequestResponseMetadata, request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]: audio_duration_s: float) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
@ -335,60 +344,65 @@ class OpenAIServingTranscription(OpenAIServing):
include_usage = request.stream_include_usage \ include_usage = request.stream_include_usage \
if request.stream_include_usage else False if request.stream_include_usage else False
include_continuous_usage = request.stream_continuous_usage_stats\ include_continuous_usage = request.stream_continuous_usage_stats\
if include_usage and request.stream_continuous_usage_stats\ if include_usage and request.stream_continuous_usage_stats\
else False else False
try: try:
async for res in result_generator: for result_generator in list_result_generator:
# On first result. async for res in result_generator:
if res.prompt_token_ids is not None: # On first result.
# Do not account the 4-tokens `<|startoftranscript|>..` if res.prompt_token_ids is not None:
# Could be negative when language token is not specified. # Do not account the 4-tokens `<|startoftranscript|>..`
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0) # Could be negative when language token
# NOTE(NickLucche) user can't pass encoder prompts directly # is not specified.
# at least not to Whisper. One indicator of the encoder num_prompt_tokens = max(
# amount of processing is the log-mel spectogram length. len(res.prompt_token_ids) - 4, 0)
num_prompt_tokens += ceil(audio_duration_s * # NOTE(NickLucche) user can't pass encoder
self.model_sr / self.hop_length) # prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectogram length.
num_prompt_tokens += ceil(
audio_duration_s * self.model_sr / self.hop_length)
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
# Just one output (n=1) supported. # Just one output (n=1) supported.
assert len(res.outputs) == 1 assert len(res.outputs) == 1
output = res.outputs[0] output = res.outputs[0]
delta_message = DeltaMessage(content=output.text) delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids) completion_tokens += len(output.token_ids)
if output.finish_reason is None: if output.finish_reason is None:
# Still generating, send delta update. # Still generating, send delta update.
choice_data = TranscriptionResponseStreamChoice( choice_data = TranscriptionResponseStreamChoice(
delta=delta_message) delta=delta_message)
else: else:
# Model is finished generating. # Model is finished generating.
choice_data = TranscriptionResponseStreamChoice( choice_data = TranscriptionResponseStreamChoice(
delta=delta_message, delta=delta_message,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason) stop_reason=output.stop_reason)
chunk = TranscriptionStreamResponse(id=request_id, chunk = TranscriptionStreamResponse(
object=chunk_object_type, id=request_id,
created=created_time, object=chunk_object_type,
choices=[choice_data], created=created_time,
model=model_name) choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous # handle usage stats if requested & if continuous
if include_continuous_usage: if include_continuous_usage:
chunk.usage = UsageInfo( chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens, total_tokens=num_prompt_tokens + completion_tokens,
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage # Once the final token is handled, if stream_options.include_usage
# is sent, send the usage. # is sent, send the usage.
@ -422,3 +436,52 @@ class OpenAIServingTranscription(OpenAIServing):
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished # Send the final done message after all response.n are finished
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
def _split_audio(self, audio_data: np.ndarray,
sample_rate: int) -> list[np.ndarray]:
chunk_size = sample_rate * self.max_audio_clip_s
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start,
search_end)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def _find_split_point(self, wav: np.ndarray, start_idx: int,
end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
for i in range(0,
len(segment) - MIN_ENERGY_WINDOW_SIZE,
MIN_ENERGY_WINDOW_SIZE):
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
energy = (window**2).mean()**0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx