mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:54:56 +08:00
[Frontend] add chunking audio for > 30s audio (#19597)
Signed-off-by: nguyenhoangthuan99 <thuanhppro12@gmail.com>
This commit is contained in:
parent
07334959d8
commit
ede5c4ebdf
@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
|
||||
language="hh",
|
||||
temperature=0.0)
|
||||
|
||||
# Expect audio too long: repeat the timeseries
|
||||
mary_had_lamb.seek(0)
|
||||
audio, sr = librosa.load(mary_had_lamb)
|
||||
repeated_audio = np.tile(audio, 10)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
buffer.seek(0)
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.audio.transcriptions.create(model=model_name,
|
||||
file=buffer,
|
||||
language="en",
|
||||
temperature=0.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_audio_request(mary_had_lamb):
|
||||
model_name = "openai/whisper-large-v3-turbo"
|
||||
server_args = ["--enforce-eager"]
|
||||
|
||||
mary_had_lamb.seek(0)
|
||||
audio, sr = librosa.load(mary_had_lamb)
|
||||
repeated_audio = np.tile(audio, 10)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
buffer.seek(0)
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
transcription = await client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=buffer,
|
||||
language="en",
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
assert out.count("Mary had a little lamb") == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from math import ceil
|
||||
from typing import Final, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@ -143,6 +145,8 @@ ISO639_1_OTHER_LANGS = {
|
||||
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||
# TODO configurable
|
||||
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
OVERLAP_CHUNK_SECOND = 1
|
||||
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAIServing):
|
||||
@ -178,7 +182,7 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[PromptType, float]:
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
@ -206,22 +210,22 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
y, sr = librosa.load(bytes_)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
if duration > self.max_audio_clip_s:
|
||||
raise ValueError(
|
||||
f"Maximum clip duration ({self.max_audio_clip_s}s) "
|
||||
"exceeded.")
|
||||
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (y, sr),
|
||||
chunks = [y] if duration < 30 else self._split_audio(y, sr)
|
||||
prompts = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (chunk, sr),
|
||||
},
|
||||
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||
}
|
||||
return cast(PromptType, prompt), duration
|
||||
"decoder_prompt":
|
||||
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||
if i == 0 else ""
|
||||
}
|
||||
prompts.append(cast(PromptType, prompt))
|
||||
return prompts, duration
|
||||
|
||||
# TODO (varun) : Make verbose response work !
|
||||
async def create_transcription(
|
||||
@ -268,7 +272,7 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
"Currently do not support PromptAdapter for Transcription."
|
||||
)
|
||||
|
||||
prompt, duration_s = await self._preprocess_transcription(
|
||||
prompts, duration_s = await self._preprocess_transcription(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
@ -277,7 +281,8 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
|
||||
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
|
||||
None]]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
@ -288,32 +293,36 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
prompt['decoder_prompt'], # type: ignore
|
||||
prompts[0]['decoder_prompt'], # type: ignore
|
||||
params=sampling_params,
|
||||
lora_request=None,
|
||||
prompt_adapter_request=None)
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
)
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
) for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return self.transcription_stream_generator(request,
|
||||
result_generator,
|
||||
list_result_generator,
|
||||
request_id,
|
||||
request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert result_generator is not None
|
||||
async for op in result_generator:
|
||||
result = op
|
||||
return TranscriptionResponse(text=result.outputs[0].text)
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
return TranscriptionResponse(text=text)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
@ -322,7 +331,7 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
result_generator: AsyncGenerator[RequestOutput, None],
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
@ -335,60 +344,65 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
# Do not account the 4-tokens `<|startoftranscript|>..`
|
||||
# Could be negative when language token is not specified.
|
||||
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
|
||||
# NOTE(NickLucche) user can't pass encoder prompts directly
|
||||
# at least not to Whisper. One indicator of the encoder
|
||||
# amount of processing is the log-mel spectogram length.
|
||||
num_prompt_tokens += ceil(audio_duration_s *
|
||||
self.model_sr / self.hop_length)
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
# Do not account the 4-tokens `<|startoftranscript|>..`
|
||||
# Could be negative when language token
|
||||
# is not specified.
|
||||
num_prompt_tokens = max(
|
||||
len(res.prompt_token_ids) - 4, 0)
|
||||
# NOTE(NickLucche) user can't pass encoder
|
||||
# prompts directly at least not to Whisper.
|
||||
# One indicator of the encoder amount of processing
|
||||
# is the log-mel spectogram length.
|
||||
num_prompt_tokens += ceil(
|
||||
audio_duration_s * self.model_sr / self.hop_length)
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = TranscriptionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
chunk = TranscriptionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
@ -422,3 +436,52 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(self, audio_data: np.ndarray,
|
||||
sample_rate: int) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.max_audio_clip_s
|
||||
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start,
|
||||
search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int,
|
||||
end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
for i in range(0,
|
||||
len(segment) - MIN_ENERGY_WINDOW_SIZE,
|
||||
MIN_ENERGY_WINDOW_SIZE):
|
||||
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
|
||||
energy = (window**2).mean()**0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user