[Frontend] add chunking audio for > 30s audio (#19597)
Signed-off-by: nguyenhoangthuan99 <thuanhppro12@gmail.com>
commit ede5c4ebdf (parent 07334959d8)
@@ -74,7 +74,12 @@ async def test_bad_requests(mary_had_lamb):
                                                   language="hh",
                                                   temperature=0.0)
 
-    # Expect audio too long: repeat the timeseries
+
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb):
+    model_name = "openai/whisper-large-v3-turbo"
+    server_args = ["--enforce-eager"]
+
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
     repeated_audio = np.tile(audio, 10)
@@ -82,11 +87,16 @@ async def test_bad_requests(mary_had_lamb):
     buffer = io.BytesIO()
     sf.write(buffer, repeated_audio, sr, format='WAV')
     buffer.seek(0)
-    with pytest.raises(openai.BadRequestError):
-        await client.audio.transcriptions.create(model=model_name,
-                                                  file=buffer,
-                                                  language="en",
-                                                  temperature=0.0)
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=buffer,
+            language="en",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert out.count("Mary had a little lamb") == 10
 
 
 @pytest.mark.asyncio
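For context, a minimal client-side sketch of what the new test exercises: sending a clip longer than 30s to a running vLLM OpenAI-compatible server with the official `openai` SDK. The base URL, API key, and file name are assumptions; with this change the server chunks the audio internally instead of returning a BadRequestError.

```python
# Minimal sketch (assumed local server on port 8000 and a local WAV file);
# it mirrors the request made in test_long_audio_request above.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    with open("long_clip.wav", "rb") as f:  # > 30s of audio
        transcription = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3-turbo",
            file=f,
            language="en",
            response_format="text",
            temperature=0.0)
    # Print the raw payload; the test above parses it with
    # json.loads(transcription)['text'].
    print(transcription)


asyncio.run(main())
```

The remaining hunks below are in the transcription serving path (class OpenAIServingTranscription).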
@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import io
+import math
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
 from typing import Final, Optional, Union, cast
 
+import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig
@@ -143,6 +145,8 @@ ISO639_1_OTHER_LANGS = {
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
+OVERLAP_CHUNK_SECOND = 1
+MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAIServingTranscription(OpenAIServing):
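A rough reading of the two new constants, assuming 16 kHz audio (Whisper's native rate) and the 30s per-chunk limit used further down; the numbers are illustrative only:

```python
# Back-of-the-envelope arithmetic behind the new constants (assumed 16 kHz).
sample_rate = 16_000           # Hz (assumed)
max_audio_clip_s = 30          # seconds per chunk fed to the model
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600  # samples

chunk_size = sample_rate * max_audio_clip_s        # 480_000 samples per chunk
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND  # 16_000 samples searched
window_ms = MIN_ENERGY_WINDOW_SIZE / sample_rate * 1000  # 100.0 ms per window

print(chunk_size, overlap_size, window_ms)
```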
@@ -178,7 +182,7 @@ class OpenAIServingTranscription(OpenAIServing):
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> tuple[PromptType, float]:
+    ) -> tuple[list[PromptType], float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -206,22 +210,22 @@ class OpenAIServingTranscription(OpenAIServing):
             y, sr = librosa.load(bytes_)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        if duration > self.max_audio_clip_s:
-            raise ValueError(
-                f"Maximum clip duration ({self.max_audio_clip_s}s) "
-                "exceeded.")
-
-        prompt = {
-            "encoder_prompt": {
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (y, sr),
-                },
-            },
-            "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
-        }
-        return cast(PromptType, prompt), duration
+        chunks = [y] if duration < 30 else self._split_audio(y, sr)
+        prompts = []
+        for i, chunk in enumerate(chunks):
+            prompt = {
+                "encoder_prompt": {
+                    "prompt": "",
+                    "multi_modal_data": {
+                        "audio": (chunk, sr),
+                    },
+                },
+                "decoder_prompt":
+                f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+                if i == 0 else ""
+            }
+            prompts.append(cast(PromptType, prompt))
+        return prompts, duration
 
     # TODO (varun) : Make verbose response work !
     async def create_transcription(
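To illustrate the new prompt construction, here is a self-contained sketch (dummy arrays in place of real audio; `lang_token` and `request_prompt` values are hypothetical) showing that only the first chunk's prompt carries the `<|startoftranscript|>` decoder prompt, while later chunks get an empty one:

```python
# Illustrative only: the shape of the `prompts` list built above for a clip
# split into three chunks. Dummy numpy arrays stand in for audio data.
import numpy as np

sr = 16_000
chunks = [np.zeros(sr * 30), np.zeros(sr * 30), np.zeros(sr * 10)]
lang_token, request_prompt = "<|en|>", ""  # hypothetical values

prompts = []
for i, chunk in enumerate(chunks):
    prompts.append({
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": (chunk, sr),
            },
        },
        "decoder_prompt":
        f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request_prompt}"
        if i == 0 else ""
    })

assert prompts[0]["decoder_prompt"].startswith("<|startoftranscript|>")
assert prompts[1]["decoder_prompt"] == ""
```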
@@ -268,7 +272,7 @@ class OpenAIServingTranscription(OpenAIServing):
                 "Currently do not support PromptAdapter for Transcription."
             )
 
-        prompt, duration_s = await self._preprocess_transcription(
+        prompts, duration_s = await self._preprocess_transcription(
             request=request,
             audio_data=audio_data,
         )
@@ -277,7 +281,8 @@ class OpenAIServingTranscription(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
+        list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
+                                                             None]]] = None
         try:
             # Unlike most decoder-only models, whisper generation length is not
             # constrained by the size of the input audio, which is mapped to a
@@ -288,32 +293,36 @@ class OpenAIServingTranscription(OpenAIServing):
 
             self._log_inputs(
                 request_id,
-                prompt['decoder_prompt'],  # type: ignore
+                prompts[0]['decoder_prompt'],  # type: ignore
                 params=sampling_params,
                 lora_request=None,
                 prompt_adapter_request=None)
 
-            result_generator = self.engine_client.generate(
-                prompt,
-                sampling_params,
-                request_id,
-            )
+            list_result_generator = [
+                self.engine_client.generate(
+                    prompt,
+                    sampling_params,
+                    request_id,
+                ) for prompt in prompts
+            ]
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         if request.stream:
             return self.transcription_stream_generator(request,
-                                                        result_generator,
+                                                        list_result_generator,
                                                         request_id,
                                                         request_metadata,
                                                         duration_s)
         # Non-streaming response.
         try:
-            assert result_generator is not None
-            async for op in result_generator:
-                result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            assert list_result_generator is not None
+            text = ""
+            for result_generator in list_result_generator:
+                async for op in result_generator:
+                    text += op.outputs[0].text
+            return TranscriptionResponse(text=text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
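The non-streaming path now drains one generator per chunk, in order, and concatenates the text. A toy, self-contained sketch of that aggregation pattern (fake generators standing in for `engine_client.generate`):

```python
# Toy version of the new non-streaming aggregation: one async generator per
# audio chunk, consumed in order, final texts concatenated.
import asyncio
from collections.abc import AsyncGenerator


async def fake_generate(text: str) -> AsyncGenerator[str, None]:
    # Stands in for engine_client.generate(prompt, ...) yielding RequestOutputs.
    yield text


async def main() -> None:
    list_result_generator = [
        fake_generate(t) for t in ("Mary ", "had ", "a lamb")
    ]
    text = ""
    for result_generator in list_result_generator:
        async for out in result_generator:
            text += out  # the real code appends op.outputs[0].text
    print(text)  # "Mary had a lamb"


asyncio.run(main())
```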
@@ -322,7 +331,7 @@ class OpenAIServingTranscription(OpenAIServing):
 
     async def transcription_stream_generator(
             self, request: TranscriptionRequest,
-            result_generator: AsyncGenerator[RequestOutput, None],
+            list_result_generator: list[AsyncGenerator[RequestOutput, None]],
             request_id: str, request_metadata: RequestResponseMetadata,
             audio_duration_s: float) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
@@ -339,17 +348,21 @@ class OpenAIServingTranscription(OpenAIServing):
             else False
 
         try:
-            async for res in result_generator:
-                # On first result.
-                if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token is not specified.
-                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder prompts directly
-                    # at least not to Whisper. One indicator of the encoder
-                    # amount of processing is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(audio_duration_s *
-                                              self.model_sr / self.hop_length)
+            for result_generator in list_result_generator:
+                async for res in result_generator:
+                    # On first result.
+                    if res.prompt_token_ids is not None:
+                        # Do not account the 4-tokens `<|startoftranscript|>..`
+                        # Could be negative when language token
+                        # is not specified.
+                        num_prompt_tokens = max(
+                            len(res.prompt_token_ids) - 4, 0)
+                        # NOTE(NickLucche) user can't pass encoder
+                        # prompts directly at least not to Whisper.
+                        # One indicator of the encoder amount of processing
+                        # is the log-mel spectogram length.
+                        num_prompt_tokens += ceil(
+                            audio_duration_s * self.model_sr / self.hop_length)
 
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
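A worked example of the prompt-token accounting above, assuming Whisper's usual 16 kHz model sample rate and a log-mel hop length of 160; the real values come from `self.model_sr` and `self.hop_length`, so treat these numbers as illustrative:

```python
# Illustrative numbers only; model_sr and hop_length come from the model config.
from math import ceil

model_sr = 16_000       # assumed Whisper sample rate
hop_length = 160        # assumed log-mel hop length
audio_duration_s = 45.0
prompt_token_ids = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>",
                    "<|notimestamps|>"]  # the 4 special tokens excluded below

num_prompt_tokens = max(len(prompt_token_ids) - 4, 0)                 # 0
num_prompt_tokens += ceil(audio_duration_s * model_sr / hop_length)   # 4500
print(num_prompt_tokens)
```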
@@ -373,7 +386,8 @@ class OpenAIServingTranscription(OpenAIServing):
                         finish_reason=output.finish_reason,
                         stop_reason=output.stop_reason)
 
-                chunk = TranscriptionStreamResponse(id=request_id,
+                    chunk = TranscriptionStreamResponse(
+                        id=request_id,
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
@@ -422,3 +436,52 @@ class OpenAIServingTranscription(OpenAIServing):
                 yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+    def _split_audio(self, audio_data: np.ndarray,
+                     sample_rate: int) -> list[np.ndarray]:
+        chunk_size = sample_rate * self.max_audio_clip_s
+        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunks = []
+        i = 0
+        while i < audio_data.shape[-1]:
+            if i + chunk_size >= audio_data.shape[-1]:
+                # handle last chunk
+                chunks.append(audio_data[..., i:])
+                break
+
+            # Find the best split point in the overlap region
+            search_start = i + chunk_size - overlap_size
+            search_end = min(i + chunk_size, audio_data.shape[-1])
+            split_point = self._find_split_point(audio_data, search_start,
+                                                 search_end)
+
+            # Extract chunk up to the split point
+            chunks.append(audio_data[..., i:split_point])
+            i = split_point
+        return chunks
+
+    def _find_split_point(self, wav: np.ndarray, start_idx: int,
+                          end_idx: int) -> int:
+        """Find the best point to split audio by
+        looking for silence or low amplitude.
+        Args:
+            wav: Audio tensor [1, T]
+            start_idx: Start index of search region
+            end_idx: End index of search region
+        Returns:
+            Index of best splitting point
+        """
+        segment = wav[start_idx:end_idx]
+
+        # Calculate RMS energy in small windows
+        min_energy = math.inf
+        quietest_idx = 0
+        for i in range(0,
+                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
+                       MIN_ENERGY_WINDOW_SIZE):
+            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+            energy = (window**2).mean()**0.5
+            if energy < min_energy:
+                quietest_idx = i + start_idx
+                min_energy = energy
+        return quietest_idx
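For readers who want to poke at the splitting behaviour outside vLLM, here is a standalone sketch of the two helpers above (same constants, but module-level functions instead of methods, and the fallback index initialised to `start_idx` rather than 0), run on synthetic audio with a silent stretch:

```python
# Standalone sketch of the chunking logic added above, runnable without vLLM.
import math

import numpy as np

MAX_AUDIO_CLIP_S = 30
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600  # ~100 ms at 16 kHz


def find_split_point(wav: np.ndarray, start_idx: int, end_idx: int) -> int:
    """Return the index of the quietest ~100 ms window in wav[start_idx:end_idx]."""
    segment = wav[start_idx:end_idx]
    min_energy = math.inf
    quietest_idx = start_idx  # degenerate to a hard cut if no window fits
    for i in range(0, len(segment) - MIN_ENERGY_WINDOW_SIZE,
                   MIN_ENERGY_WINDOW_SIZE):
        window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
        energy = (window**2).mean()**0.5  # RMS energy of the window
        if energy < min_energy:
            quietest_idx = i + start_idx
            min_energy = energy
    return quietest_idx


def split_audio(audio: np.ndarray, sample_rate: int) -> list[np.ndarray]:
    """Cut audio into <= 30 s chunks, splitting inside the quietest part of
    the last second of each chunk so words are less likely to be cut."""
    chunk_size = sample_rate * MAX_AUDIO_CLIP_S
    overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
    chunks, i = [], 0
    while i < audio.shape[-1]:
        if i + chunk_size >= audio.shape[-1]:
            chunks.append(audio[..., i:])  # last (short) chunk
            break
        split_point = find_split_point(audio, i + chunk_size - overlap_size,
                                       min(i + chunk_size, audio.shape[-1]))
        chunks.append(audio[..., i:split_point])
        i = split_point
    return chunks


# 70 s of a 440 Hz tone with second 29-30 silenced: the first split should
# land inside that silent stretch rather than at exactly 30 s.
sr = 16_000
t = np.arange(70 * sr) / sr
audio = np.sin(2 * np.pi * 440 * t)
audio[29 * sr:30 * sr] = 0.0
chunks = split_audio(audio, sr)
print([len(c) / sr for c in chunks])  # first chunk ends around the 29 s mark
```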