[Frontend] Add audio chunking for audio longer than 30s (#19597)

Signed-off-by: nguyenhoangthuan99 <thuanhppro12@gmail.com>
Author: nguyenhoangthuan99, 2025-06-17 10:34:00 +07:00 (committed by GitHub)
Parent: 07334959d8
Commit: ede5c4ebdf
2 changed files with 161 additions and 88 deletions
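In short, the transcription endpoint previously rejected audio longer than the 30-second Whisper window; with this change the server splits long clips into overlapping chunks, transcribes each chunk, and concatenates the results. A minimal client-side sketch of the new behavior, assuming a locally running vLLM OpenAI-compatible server (the base URL, file path, and launch command are placeholders; the response is parsed the same way the PR's test does):

import asyncio
import json
import openai

# Assumed setup: a local server started with something like
#   vllm serve openai/whisper-large-v3-turbo
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def transcribe_long_clip(path: str) -> str:
    # Clips longer than 30s no longer fail with
    # "Maximum clip duration exceeded"; they are chunked server-side.
    with open(path, "rb") as f:
        transcription = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3-turbo",
            file=f,
            language="en",
            response_format="text",
            temperature=0.0)
    return json.loads(transcription)["text"]

if __name__ == "__main__":
    print(asyncio.run(transcribe_long_clip("long_clip.wav")))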


@@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
language="hh",
temperature=0.0)
# Expect audio too long: repeat the timeseries
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=model_name,
file=buffer,
language="en",
temperature=0.0)
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=buffer,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10
@pytest.mark.asyncio
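The new long-audio test builds its over-length input by tiling the original clip ten times and writing it back to a WAV buffer, then expects the phrase to appear once per repetition in the transcript. Factored out, that setup looks like this (a sketch; the helper name is illustrative and not part of the PR):

import io

import librosa
import numpy as np
import soundfile as sf

def repeat_clip_to_buffer(path: str, repeats: int = 10) -> io.BytesIO:
    # Tile the waveform so its duration comfortably exceeds the 30s limit.
    audio, sr = librosa.load(path)
    repeated = np.tile(audio, repeats)
    assert librosa.get_duration(y=repeated, sr=sr) > 30
    buffer = io.BytesIO()
    sf.write(buffer, repeated, sr, format="WAV")
    buffer.seek(0)
    return buffer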


@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
from collections.abc import AsyncGenerator
from math import ceil
from typing import Final, Optional, Union, cast
import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
@@ -143,6 +145,8 @@ ISO639_1_OTHER_LANGS = {
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
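Both constants are expressed in samples or seconds rather than spectrogram frames, so their effective sizes depend on the decoded sample rate; a quick back-of-the-envelope check (librosa's 22 050 Hz default decode rate is an assumption here, while the comment above describes the 16 kHz case):

sr = 22050                                      # assumed librosa default decode rate
overlap = sr * OVERLAP_CHUNK_SECOND             # 22050 samples searched per chunk boundary
windows = overlap // MIN_ENERGY_WINDOW_SIZE     # ~13 RMS windows per search
window_ms = 1000 * MIN_ENERGY_WINDOW_SIZE / sr  # ~72.6 ms per window at 22.05 kHz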
class OpenAIServingTranscription(OpenAIServing):
@@ -178,7 +182,7 @@ class OpenAIServingTranscription(OpenAIServing):
self,
request: TranscriptionRequest,
audio_data: bytes,
) -> tuple[PromptType, float]:
) -> tuple[list[PromptType], float]:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
@@ -206,22 +210,22 @@ class OpenAIServingTranscription(OpenAIServing):
y, sr = librosa.load(bytes_)
duration = librosa.get_duration(y=y, sr=sr)
if duration > self.max_audio_clip_s:
raise ValueError(
f"Maximum clip duration ({self.max_audio_clip_s}s) "
"exceeded.")
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (y, sr),
chunks = [y] if duration < 30 else self._split_audio(y, sr)
prompts = []
for i, chunk in enumerate(chunks):
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (chunk, sr),
},
},
},
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
}
return cast(PromptType, prompt), duration
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
if i == 0 else ""
}
prompts.append(cast(PromptType, prompt))
return prompts, duration
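For illustration, a request that splits into two chunks yields a prompts list of the following shape (chunk0, chunk1, and sr stand in for the actual arrays and sample rate; language en, empty request.prompt); only the first entry carries the Whisper task tokens, later chunks get an empty decoder prompt:

prompts = [
    {
        "encoder_prompt": {"prompt": "", "multi_modal_data": {"audio": (chunk0, sr)}},
        "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    },
    {
        "encoder_prompt": {"prompt": "", "multi_modal_data": {"audio": (chunk1, sr)}},
        "decoder_prompt": "",
    },
]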
# TODO (varun) : Make verbose response work !
async def create_transcription(
@@ -268,7 +272,7 @@ class OpenAIServingTranscription(OpenAIServing):
"Currently do not support PromptAdapter for Transcription."
)
prompt, duration_s = await self._preprocess_transcription(
prompts, duration_s = await self._preprocess_transcription(
request=request,
audio_data=audio_data,
)
@@ -277,7 +281,8 @@ class OpenAIServingTranscription(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
None]]] = None
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
@@ -288,32 +293,36 @@ class OpenAIServingTranscription(OpenAIServing):
self._log_inputs(
request_id,
prompt['decoder_prompt'], # type: ignore
prompts[0]['decoder_prompt'], # type: ignore
params=sampling_params,
lora_request=None,
prompt_adapter_request=None)
result_generator = self.engine_client.generate(
prompt,
sampling_params,
request_id,
)
list_result_generator = [
self.engine_client.generate(
prompt,
sampling_params,
request_id,
) for prompt in prompts
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if request.stream:
return self.transcription_stream_generator(request,
result_generator,
list_result_generator,
request_id,
request_metadata,
duration_s)
# Non-streaming response.
try:
assert result_generator is not None
async for op in result_generator:
result = op
return TranscriptionResponse(text=result.outputs[0].text)
assert list_result_generator is not None
text = ""
for result_generator in list_result_generator:
async for op in result_generator:
text += op.outputs[0].text
return TranscriptionResponse(text=text)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
@@ -322,7 +331,7 @@ class OpenAIServingTranscription(OpenAIServing):
async def transcription_stream_generator(
self, request: TranscriptionRequest,
result_generator: AsyncGenerator[RequestOutput, None],
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
created_time = int(time.time())
@@ -335,60 +344,65 @@ class OpenAIServingTranscription(OpenAIServing):
include_usage = request.stream_include_usage \
if request.stream_include_usage else False
include_continuous_usage = request.stream_continuous_usage_stats\
if include_usage and request.stream_continuous_usage_stats\
else False
try:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account the 4-tokens `<|startoftranscript|>..`
# Could be negative when language token is not specified.
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) user can't pass encoder prompts directly
# at least not to Whisper. One indicator of the encoder
# amount of processing is the log-mel spectogram length.
num_prompt_tokens += ceil(audio_duration_s *
self.model_sr / self.hop_length)
for result_generator in list_result_generator:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account the 4-tokens `<|startoftranscript|>..`
# Could be negative when language token
# is not specified.
num_prompt_tokens = max(
len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) user can't pass encoder
# prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectogram length.
num_prompt_tokens += ceil(
audio_duration_s * self.model_sr / self.hop_length)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message)
else:
# Model is finished generating.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = TranscriptionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, if stream_options.include_usage
# is sent, send the usage.
@@ -422,3 +436,52 @@ class OpenAIServingTranscription(OpenAIServing):
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
def _split_audio(self, audio_data: np.ndarray,
sample_rate: int) -> list[np.ndarray]:
chunk_size = sample_rate * self.max_audio_clip_s
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
chunks = []
i = 0
while i < audio_data.shape[-1]:
if i + chunk_size >= audio_data.shape[-1]:
# handle last chunk
chunks.append(audio_data[..., i:])
break
# Find the best split point in the overlap region
search_start = i + chunk_size - overlap_size
search_end = min(i + chunk_size, audio_data.shape[-1])
split_point = self._find_split_point(audio_data, search_start,
search_end)
# Extract chunk up to the split point
chunks.append(audio_data[..., i:split_point])
i = split_point
return chunks
def _find_split_point(self, wav: np.ndarray, start_idx: int,
end_idx: int) -> int:
"""Find the best point to split audio by
looking for silence or low amplitude.
Args:
wav: Audio tensor [1, T]
start_idx: Start index of search region
end_idx: End index of search region
Returns:
Index of best splitting point
"""
segment = wav[start_idx:end_idx]
# Calculate RMS energy in small windows
min_energy = math.inf
quietest_idx = 0
for i in range(0,
len(segment) - MIN_ENERGY_WINDOW_SIZE,
MIN_ENERGY_WINDOW_SIZE):
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
energy = (window**2).mean()**0.5
if energy < min_energy:
quietest_idx = i + start_idx
min_energy = energy
return quietest_idx
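To see the two helpers in action outside the serving class, here is a self-contained sketch that mirrors their logic (module-level copies under an assumed 30s limit, not the PR's API) and checks that every chunk stays within the clip limit while cuts land in quiet regions:

import math

import numpy as np

MAX_CLIP_S = 30                 # stands in for self.max_audio_clip_s
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600   # ~100 ms at 16 kHz

def find_split_point(wav: np.ndarray, start_idx: int, end_idx: int) -> int:
    # Return the start of the lowest-RMS window inside [start_idx, end_idx).
    segment = wav[start_idx:end_idx]
    min_energy, quietest_idx = math.inf, start_idx
    for i in range(0, len(segment) - MIN_ENERGY_WINDOW_SIZE, MIN_ENERGY_WINDOW_SIZE):
        window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
        energy = (window ** 2).mean() ** 0.5
        if energy < min_energy:
            quietest_idx, min_energy = i + start_idx, energy
    return quietest_idx

def split_audio(audio: np.ndarray, sr: int) -> list[np.ndarray]:
    chunk_size = sr * MAX_CLIP_S
    overlap_size = sr * OVERLAP_CHUNK_SECOND
    chunks, i = [], 0
    while i < audio.shape[-1]:
        if i + chunk_size >= audio.shape[-1]:
            chunks.append(audio[..., i:])        # final, possibly short, chunk
            break
        # Look for a quiet cut point in the last second of the would-be chunk.
        split = find_split_point(audio, i + chunk_size - overlap_size,
                                 min(i + chunk_size, audio.shape[-1]))
        chunks.append(audio[..., i:split])
        i = split
    return chunks

if __name__ == "__main__":
    sr = 16000
    audio = np.random.randn(75 * sr).astype(np.float32)
    audio[29 * sr:30 * sr] = 0.0    # quiet gaps the splitter should latch onto
    audio[58 * sr:59 * sr] = 0.0
    for chunk in split_audio(audio, sr):
        assert chunk.shape[-1] <= MAX_CLIP_S * sr
        print(f"chunk of {chunk.shape[-1] / sr:.2f}s")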