[Frontend] Add /v1/audio/translations OpenAI API endpoint (#19615)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 8359f4c8d8
commit e795d723ed
@ -57,6 +57,8 @@ We currently support the following OpenAI APIs:
    - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
- [Transcriptions API][transcriptions-api] (`/v1/audio/transcriptions`)
    - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).
- [Translation API][translations-api] (`/v1/audio/translations`)
    - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).

In addition, we have the following custom APIs:

@ -374,6 +376,34 @@ The following extra parameters are supported:

```python
--8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params"
```

[](){ #translations-api }

### Translations API

Our Translation API is compatible with [OpenAI's Translations API](https://platform.openai.com/docs/api-reference/audio/createTranslation);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
Whisper models can translate audio from any of the 55 supported non-English languages into English.
Note that the popular `openai/whisper-large-v3-turbo` model does not support translation.

!!! note
    To use the Translation API, please install with extra audio dependencies using `pip install vllm[audio]`.

Code example: <gh-file:examples/online_serving/openai_translation_client.py>

#### Extra Parameters

The following [sampling parameters][sampling-params] are supported.

```python
--8<-- "vllm/entrypoints/openai/protocol.py:translation-sampling-params"
```

The following extra parameters are supported:

```python
--8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params"
```
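
For quick reference, here is a minimal non-streaming request against a locally running server. This is a sketch mirroring the example file referenced above; it assumes a Whisper model is being served at `http://localhost:8000` and that `sample_it.wav` is a local Italian-language audio file:

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

with open("sample_it.wav", "rb") as f:  # hypothetical local audio file
    result = client.audio.translations.create(
        model="openai/whisper-large-v3",
        file=f,
        response_format="json",
        temperature=0.0,
        # vLLM-specific extra parameters documented above.
        extra_body=dict(language="it"),
    )

print(result.text)
```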
[](){ #tokenizer-api }
||||
@ -26,23 +26,12 @@ from openai import OpenAI
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
|
||||
winning_call = AudioAsset("winning_call").get_local_path()
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
|
||||
def sync_openai():
|
||||
def sync_openai(audio_path: str, client: OpenAI):
|
||||
"""
|
||||
Perform synchronous transcription using OpenAI-compatible API.
|
||||
"""
|
||||
with open(str(mary_had_lamb), "rb") as f:
|
||||
with open(audio_path, "rb") as f:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
file=f,
|
||||
model="openai/whisper-large-v3",
|
||||
@ -58,8 +47,7 @@ def sync_openai():
|
||||
print("transcription result:", transcription.text)
|
||||
|
||||
|
||||
# OpenAI Transcription API client does not support streaming.
|
||||
async def stream_openai_response():
|
||||
async def stream_openai_response(audio_path: str, base_url: str, api_key: str):
|
||||
"""
|
||||
Perform streaming transcription using vLLM's raw HTTP streaming API.
|
||||
"""
|
||||
@ -68,11 +56,12 @@ async def stream_openai_response():
|
||||
"stream": True,
|
||||
"model": "openai/whisper-large-v3",
|
||||
}
|
||||
url = openai_api_base + "/audio/transcriptions"
|
||||
headers = {"Authorization": f"Bearer {openai_api_key}"}
|
||||
url = base_url + "/audio/transcriptions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
print("transcription result:", end=" ")
|
||||
# OpenAI Transcription API client does not support streaming.
|
||||
async with httpx.AsyncClient() as client:
|
||||
with open(str(winning_call), "rb") as f:
|
||||
with open(audio_path, "rb") as f:
|
||||
async with client.stream(
|
||||
"POST", url, files={"file": f}, data=data, headers=headers
|
||||
) as response:
|
||||
@ -93,10 +82,20 @@ async def stream_openai_response():
|
||||
|
||||
|
||||
def main():
|
||||
sync_openai()
|
||||
mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
|
||||
winning_call = str(AudioAsset("winning_call").get_local_path())
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
sync_openai(mary_had_lamb, client)
|
||||
# Run the asynchronous function
|
||||
asyncio.run(stream_openai_response())
|
||||
asyncio.run(stream_openai_response(winning_call, openai_api_base, openai_api_key))
|
||||
|
||||
|
||||
if __name__ == "__main__":
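
With this refactor the example helpers take the audio path and client explicitly instead of relying on module-level globals, so they can be reused outside `main()`. A small sketch (the file path below is hypothetical, and it assumes the refactored helpers are importable from the example module):

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
sync_openai("my_recording.wav", client)  # hypothetical local audio file
```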
|
||||
|
||||
examples/online_serving/openai_translation_client.py (new file, 75 lines)
@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json

import httpx
from openai import OpenAI

from vllm.assets.audio import AudioAsset


def sync_openai(audio_path: str, client: OpenAI):
    with open(audio_path, "rb") as f:
        translation = client.audio.translations.create(
            file=f,
            model="openai/whisper-large-v3",
            response_format="json",
            temperature=0.0,
            # Additional params not provided by OpenAI API.
            extra_body=dict(
                language="it",
                seed=4419,
                repetition_penalty=1.3,
            ),
        )
        print("translation result:", translation.text)


async def stream_openai_response(audio_path: str, base_url: str, api_key: str):
    data = {
        "language": "it",
        "stream": True,
        "model": "openai/whisper-large-v3",
    }
    url = base_url + "/audio/translations"
    headers = {"Authorization": f"Bearer {api_key}"}
    print("translation result:", end=" ")
    # OpenAI translation API client does not support streaming.
    async with httpx.AsyncClient() as client:
        with open(audio_path, "rb") as f:
            async with client.stream(
                "POST", url, files={"file": f}, data=data, headers=headers
            ) as response:
                async for line in response.aiter_lines():
                    # Each line is a JSON object prefixed with 'data: '
                    if line:
                        if line.startswith("data: "):
                            line = line[len("data: ") :]
                        # Last chunk, stream ends
                        if line.strip() == "[DONE]":
                            break
                        # Parse the JSON response
                        chunk = json.loads(line)
                        # Extract and print the content
                        content = chunk["choices"][0].get("delta", {}).get("content")
                        print(content, end="")


def main():
    foscolo = str(AudioAsset("azacinto_foscolo").get_local_path())

    # Modify OpenAI's API key and API base to use vLLM's API server.
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )
    sync_openai(foscolo, client)
    # Run the asynchronous function
    asyncio.run(stream_openai_response(foscolo, openai_api_base, openai_api_key))


if __name__ == "__main__":
    main()
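
For readers following the streaming branch above: each `data:` line the server emits is a JSON-serialized `TranslationStreamResponse` (defined in `protocol.py` later in this commit). A sketch of one parsed chunk, with illustrative values only:

```python
# Illustrative shape only; field names follow TranslationStreamResponse /
# TranslationResponseStreamChoice from this commit, values are made up.
chunk = {
    "id": "trsl-...",  # request id
    "object": "translation.chunk",
    "created": 1700000000,
    "model": "openai/whisper-large-v3",
    "choices": [
        {"delta": {"content": "Nor will I ever "}, "finish_reason": None, "stop_reason": None}
    ],
    "usage": None,  # populated only when the stream usage options are enabled
}
print(chunk["choices"][0]["delta"]["content"])
```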
|
||||
@ -82,6 +82,8 @@ async def test_long_audio_request(mary_had_lamb):
|
||||
|
||||
mary_had_lamb.seek(0)
|
||||
audio, sr = librosa.load(mary_had_lamb)
|
||||
# Add small silence after each audio for repeatability in the split process
|
||||
audio = np.pad(audio, (0, 1600))
|
||||
repeated_audio = np.tile(audio, 10)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
|
||||
tests/entrypoints/openai/test_translation_validation.py (new file, 172 lines)
@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
# imports for guided decoding tests
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
from openai._base_client import AsyncAPIClient
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def foscolo():
|
||||
# Test translation it->en
|
||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
|
||||
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_audio(foscolo):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
translation = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
response_format="text",
|
||||
# TODO remove once language detection is implemented
|
||||
extra_body=dict(language="it"),
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip()
|
||||
assert "Nor will I ever touch the sacred" in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_prompt(foscolo):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
# Condition whisper on starting text
|
||||
prompt = "Nor have I ever"
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
transcription = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
prompt=prompt,
|
||||
extra_body=dict(language="it"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
assert "Nor will I ever touch the sacred" not in out
|
||||
assert prompt not in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_asr_model(foscolo):
|
||||
# text to text model
|
||||
model_name = "JackFram/llama-68m"
|
||||
server_args = ["--enforce-eager"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.translations.create(model=model_name,
|
||||
file=foscolo,
|
||||
temperature=0.0)
|
||||
assert res.code == 400 and not res.text
|
||||
assert res.message == "The model does not support Translations API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_response(foscolo):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
translation = ""
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res_no_stream = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
response_format="json",
|
||||
extra_body=dict(language="it"),
|
||||
temperature=0.0)
|
||||
# Unfortunately this only works when the openai client is patched
|
||||
# to use streaming mode, not exposed in the translation api.
|
||||
original_post = AsyncAPIClient.post
|
||||
|
||||
async def post_with_stream(*args, **kwargs):
|
||||
kwargs['stream'] = True
|
||||
return await original_post(*args, **kwargs)
|
||||
|
||||
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.translations.create(model=model_name,
|
||||
file=foscolo,
|
||||
temperature=0.0,
|
||||
extra_body=dict(
|
||||
stream=True,
|
||||
language="it"))
|
||||
# Reconstruct from chunks and validate
|
||||
async for chunk in res:
|
||||
# just a chunk
|
||||
text = chunk.choices[0]['delta']['content']
|
||||
translation += text
|
||||
|
||||
assert translation == res_no_stream.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_options(foscolo):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
original_post = AsyncAPIClient.post
|
||||
|
||||
async def post_with_stream(*args, **kwargs):
|
||||
kwargs['stream'] = True
|
||||
return await original_post(*args, **kwargs)
|
||||
|
||||
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
temperature=0.0,
|
||||
extra_body=dict(language="it",
|
||||
stream=True,
|
||||
stream_include_usage=True,
|
||||
stream_continuous_usage_stats=True))
|
||||
final = False
|
||||
continuous = True
|
||||
async for chunk in res:
|
||||
if not len(chunk.choices):
|
||||
# final usage sent
|
||||
final = True
|
||||
else:
|
||||
continuous = continuous and hasattr(chunk, 'usage')
|
||||
assert final and continuous
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_audio_request(foscolo):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
|
||||
foscolo.seek(0)
|
||||
audio, sr = librosa.load(foscolo)
|
||||
repeated_audio = np.tile(audio, 2)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
buffer.seek(0)
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
translation = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=buffer,
|
||||
extra_body=dict(language="it"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
# TODO investigate higher model uncertainty for longer translations.
|
||||
assert out.count("nor will i ever") == 2
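
The long-audio test above exercises the server-side chunking path: clips longer than the 30-second Whisper window are split on low-energy points and decoded chunk by chunk (see `_split_audio` in the new `speech_to_text.py` below). As a rough illustration of the prompt-token accounting used for usage stats in the serving code later in this diff (a sketch; the 16 kHz sampling rate and 160-sample hop length are Whisper feature-extractor defaults assumed here, not values taken from this diff):

```python
from math import ceil

audio_duration_s = 60.0  # hypothetical one-minute clip
model_sr = 16000         # assumed Whisper feature-extractor sampling rate
hop_length = 160         # assumed Whisper feature-extractor hop length

# Mirrors: num_prompt_tokens += ceil(audio_duration_s * model_sr / hop_length)
print(ceil(audio_duration_s * model_sr / hop_length))  # 6000 log-mel frames counted as prompt tokens
```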
|
||||
@ -73,6 +73,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
UnloadLoRAAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
@ -88,7 +90,7 @@ from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranscription)
|
||||
OpenAIServingTranscription, OpenAIServingTranslation)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||
with_cancellation)
|
||||
@ -401,6 +403,10 @@ def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
@ -774,6 +780,47 @@ async def create_transcriptions(raw_request: Request,
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/audio/translations",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {
|
||||
"content": {
|
||||
"text/event-stream": {}
|
||||
}
|
||||
},
|
||||
HTTPStatus.BAD_REQUEST.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
|
||||
"model": ErrorResponse
|
||||
},
|
||||
})
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(request: Annotated[TranslationRequest,
|
||||
Form()],
|
||||
raw_request: Request):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Translations API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
generator = await handler.create_translation(audio_data, request,
|
||||
raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
|
||||
elif isinstance(generator, TranslationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
@ -1248,6 +1295,12 @@ async def init_app_state(
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if model_config.runner_type == "transcription" else None
|
||||
state.openai_serving_translation = OpenAIServingTranslation(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if model_config.runner_type == "transcription" else None
|
||||
state.task = model_config.task
|
||||
|
||||
state.enable_server_load_tracking = args.enable_server_load_tracking
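
Since the new route accepts `multipart/form-data` (the `TranslationRequest` is bound with `Form()` above), it can also be called without the OpenAI client. A sketch using `httpx`; the URL, token, and file name are placeholders:

```python
import httpx

with open("sample_it.wav", "rb") as f:  # hypothetical local audio file
    resp = httpx.post(
        "http://localhost:8000/v1/audio/translations",
        headers={"Authorization": "Bearer EMPTY"},
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "language": "it",          # vLLM extra parameter
            "response_format": "json",
        },
    )

print(resp.json()["text"])
```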
|
||||
|
||||
@ -1947,3 +1947,190 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
|
||||
words: Optional[list[TranscriptionWord]] = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class TranslationStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
||||
object: Literal["translation.chunk"] = "translation.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranslationResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class TranslationRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to translate, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
# TODO support additional sampling parameters
|
||||
# --8<-- [start:translation-sampling-params]
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
# --8<-- [end:translation-sampling-params]
|
||||
|
||||
# --8<-- [start:translation-extra-params]
|
||||
language: Optional[str] = None
|
||||
"""The language of the input audio we translate from.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy.
|
||||
"""
|
||||
|
||||
stream: Optional[bool] = False
|
||||
"""Custom field not present in the original OpenAI definition. When set,
|
||||
it will enable output to be streamed in a similar fashion as the Chat
|
||||
Completion endpoint.
|
||||
"""
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: Optional[bool] = False
|
||||
stream_continuous_usage_stats: Optional[bool] = False
|
||||
# --8<-- [end:translation-extra-params]
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
|
||||
return SamplingParams.from_optional(temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Translation response objects
|
||||
class TranslationResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
|
||||
class TranslationWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
segments: Optional[list[TranslationSegment]] = None
|
||||
"""Segments of the translated text and their corresponding details."""
|
||||
|
||||
words: Optional[list[TranslationWord]] = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
@ -58,7 +58,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse)
|
||||
TranscriptionResponse,
|
||||
TranslationRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
# yapf: enable
|
||||
@ -89,9 +90,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
|
||||
TranscriptionRequest]
|
||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
|
||||
@ -1,155 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from math import ceil
|
||||
from typing import Final, Optional, Union, cast
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
|
||||
# TODO these configs should live somewhere with the model so we can support
|
||||
# additional ones
|
||||
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"af": "Afrikaans",
|
||||
"ar": "Arabic",
|
||||
"hy": "Armenian",
|
||||
"az": "Azerbaijani",
|
||||
"be": "Belarusian",
|
||||
"bs": "Bosnian",
|
||||
"bg": "Bulgarian",
|
||||
"ca": "Catalan",
|
||||
"zh": "Chinese",
|
||||
"hr": "Croatian",
|
||||
"cs": "Czech",
|
||||
"da": "Danish",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"et": "Estonian",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"gl": "Galician",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"hu": "Hungarian",
|
||||
"is": "Icelandic",
|
||||
"id": "Indonesian",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"kn": "Kannada",
|
||||
"kk": "Kazakh",
|
||||
"ko": "Korean",
|
||||
"lv": "Latvian",
|
||||
"lt": "Lithuanian",
|
||||
"mk": "Macedonian",
|
||||
"ms": "Malay",
|
||||
"mr": "Marathi",
|
||||
"mi": "Maori",
|
||||
"ne": "Nepali",
|
||||
"no": "Norwegian",
|
||||
"fa": "Persian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ro": "Romanian",
|
||||
"ru": "Russian",
|
||||
"sr": "Serbian",
|
||||
"sk": "Slovak",
|
||||
"sl": "Slovenian",
|
||||
"es": "Spanish",
|
||||
"sw": "Swahili",
|
||||
"sv": "Swedish",
|
||||
"tl": "Tagalog",
|
||||
"ta": "Tamil",
|
||||
"th": "Thai",
|
||||
"tr": "Turkish",
|
||||
"uk": "Ukrainian",
|
||||
"ur": "Urdu",
|
||||
"vi": "Vietnamese",
|
||||
"cy": "Welsh"
|
||||
}
|
||||
ISO639_1_OTHER_LANGS = {
|
||||
"lo": "Lao",
|
||||
"jw": "Javanese",
|
||||
"tk": "Turkmen",
|
||||
"yi": "Yiddish",
|
||||
"so": "Somali",
|
||||
"bn": "Bengali",
|
||||
"nn": "Norwegian Nynorsk",
|
||||
"si": "Sinhala",
|
||||
"yo": "Yoruba",
|
||||
"sa": "Sanskrit",
|
||||
"mi": "Māori",
|
||||
"fo": "Faroese", # codespell:ignore
|
||||
"mt": "Maltese",
|
||||
"tg": "Tajik",
|
||||
"mg": "Malagasy",
|
||||
"haw": "Hawaiian",
|
||||
"km": "Khmer",
|
||||
"br": "Breton",
|
||||
"ps": "Pashto",
|
||||
"ln": "Lingala",
|
||||
"la": "Latin",
|
||||
"ml": "Malayalam",
|
||||
"sq": "Albanian",
|
||||
"su": "Sundanese",
|
||||
"eu": "Basque",
|
||||
"ka": "Georgian",
|
||||
"uz": "Uzbek",
|
||||
"sn": "Shona",
|
||||
"ht": "Haitian",
|
||||
"as": "Assamese",
|
||||
"mn": "Mongolian",
|
||||
"te": "Telugu",
|
||||
"pa": "Panjabi",
|
||||
"tt": "Tatar",
|
||||
"gu": "Gujarati",
|
||||
"oc": "Occitan",
|
||||
"ha": "Hausa",
|
||||
"ba": "Bashkir",
|
||||
"my": "Burmese",
|
||||
"sd": "Sindhi",
|
||||
"am": "Amharic",
|
||||
"lb": "Luxembourgish",
|
||||
"bo": "Tibetan"
|
||||
}
|
||||
|
||||
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||
# TODO configurable
|
||||
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
OVERLAP_CHUNK_SECOND = 1
|
||||
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
|
||||
|
||||
|
||||
class OpenAIServingTranscription(OpenAIServing):
|
||||
class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
"""Handles transcription requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -164,70 +37,9 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="transcribe")
|
||||
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
processor = cached_get_processor(model_config.model)
|
||||
self.max_audio_clip_s = processor.feature_extractor.chunk_length
|
||||
self.model_sr = processor.feature_extractor.sampling_rate
|
||||
self.hop_length = processor.feature_extractor.hop_length
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
async def _preprocess_transcription(
|
||||
self,
|
||||
request: TranscriptionRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
|
||||
if request.language:
|
||||
if request.language in ISO639_1_SUPPORTED_LANGS:
|
||||
pass
|
||||
elif request.language in ISO639_1_OTHER_LANGS:
|
||||
logger.warning(
|
||||
"The selected language %s has limited accuracy with"
|
||||
" reported WER>=0.5. Results may be less accurate "
|
||||
"for this choice.", request.language)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported language: {request.language}."
|
||||
"Language should be one of:" +
|
||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
y, sr = librosa.load(bytes_)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
chunks = [y] if duration < 30 else self._split_audio(y, sr)
|
||||
prompts = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (chunk, sr),
|
||||
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||
if i == 0 else ""
|
||||
}
|
||||
prompts.append(cast(PromptType, prompt))
|
||||
return prompts, duration
|
||||
|
||||
# TODO (varun) : Make verbose response work !
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest,
|
||||
raw_request: Request
|
||||
@ -238,250 +50,83 @@ class OpenAIServingTranscription(OpenAIServing):
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
for the API specification. This API mimics the OpenAI transcription API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ['text', 'json']:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"trsc-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for Transcription.")
|
||||
if prompt_adapter_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support PromptAdapter for Transcription."
|
||||
)
|
||||
|
||||
prompts, duration_s = await self._preprocess_transcription(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
|
||||
None]]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
prompts[0]['decoder_prompt'], # type: ignore
|
||||
params=sampling_params,
|
||||
lora_request=None,
|
||||
prompt_adapter_request=None)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
) for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return self.transcription_stream_generator(request,
|
||||
list_result_generator,
|
||||
request_id,
|
||||
request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
return TranscriptionResponse(text=text)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
async def transcription_stream_generator(
|
||||
self, request: TranscriptionRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
chunk_object_type: Final = "transcription.chunk"
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="transcription.chunk",
|
||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||
stream_response_class=TranscriptionStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
# Do not account the 4-tokens `<|startoftranscript|>..`
|
||||
# Could be negative when language token
|
||||
# is not specified.
|
||||
num_prompt_tokens = max(
|
||||
len(res.prompt_token_ids) - 4, 0)
|
||||
# NOTE(NickLucche) user can't pass encoder
|
||||
# prompts directly at least not to Whisper.
|
||||
# One indicator of the encoder amount of processing
|
||||
# is the log-mel spectogram length.
|
||||
num_prompt_tokens += ceil(
|
||||
audio_duration_s * self.model_sr / self.hop_length)
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate")
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest,
|
||||
raw_request: Request
|
||||
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = TranscriptionResponseStreamChoice(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = TranscriptionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = TranscriptionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in chat completion stream generator.")
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(self, audio_data: np.ndarray,
|
||||
sample_rate: int) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.max_audio_clip_s
|
||||
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start,
|
||||
search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int,
|
||||
end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
for i in range(0,
|
||||
len(segment) - MIN_ENERGY_WINDOW_SIZE,
|
||||
MIN_ENERGY_WINDOW_SIZE):
|
||||
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
|
||||
energy = (window**2).mean()**0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
|
||||
async def translation_stream_generator(
|
||||
self, request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str, request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
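
After this refactor both serving classes share `OpenAISpeechToText` and differ mainly in the Whisper task token placed in the decoder prompt. A sketch of the prompt built by `_preprocess_speech_to_text` for a translation request with `language="it"` and an empty `prompt`, per the f-string in the new base class below:

```python
lang_token = "<|it|>"    # from request.language
task_type = "translate"  # "transcribe" for OpenAIServingTranscription
decoder_prompt = f"<|startoftranscript|>{lang_token}<|{task_type}|><|notimestamps|>"
print(decoder_prompt)  # <|startoftranscript|><|it|><|translate|><|notimestamps|>
```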
|
||||
|
||||
vllm/entrypoints/openai/speech_to_text.py (new file, 503 lines)
@ -0,0 +1,503 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from math import ceil
|
||||
from typing import Callable, Literal, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage, ErrorResponse, RequestResponseMetadata,
|
||||
TranscriptionResponse, TranscriptionResponseStreamChoice,
|
||||
TranscriptionStreamResponse, TranslationResponse,
|
||||
TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
SpeechToTextRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse]
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
||||
# TODO these configs should live somewhere with the model so we can support
|
||||
# additional ones
|
||||
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"af": "Afrikaans",
|
||||
"ar": "Arabic",
|
||||
"hy": "Armenian",
|
||||
"az": "Azerbaijani",
|
||||
"be": "Belarusian",
|
||||
"bs": "Bosnian",
|
||||
"bg": "Bulgarian",
|
||||
"ca": "Catalan",
|
||||
"zh": "Chinese",
|
||||
"hr": "Croatian",
|
||||
"cs": "Czech",
|
||||
"da": "Danish",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"et": "Estonian",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"gl": "Galician",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"hu": "Hungarian",
|
||||
"is": "Icelandic",
|
||||
"id": "Indonesian",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"kn": "Kannada",
|
||||
"kk": "Kazakh",
|
||||
"ko": "Korean",
|
||||
"lv": "Latvian",
|
||||
"lt": "Lithuanian",
|
||||
"mk": "Macedonian",
|
||||
"ms": "Malay",
|
||||
"mr": "Marathi",
|
||||
"mi": "Maori",
|
||||
"ne": "Nepali",
|
||||
"no": "Norwegian",
|
||||
"fa": "Persian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ro": "Romanian",
|
||||
"ru": "Russian",
|
||||
"sr": "Serbian",
|
||||
"sk": "Slovak",
|
||||
"sl": "Slovenian",
|
||||
"es": "Spanish",
|
||||
"sw": "Swahili",
|
||||
"sv": "Swedish",
|
||||
"tl": "Tagalog",
|
||||
"ta": "Tamil",
|
||||
"th": "Thai",
|
||||
"tr": "Turkish",
|
||||
"uk": "Ukrainian",
|
||||
"ur": "Urdu",
|
||||
"vi": "Vietnamese",
|
||||
"cy": "Welsh"
|
||||
}
|
||||
ISO639_1_OTHER_LANGS = {
|
||||
"lo": "Lao",
|
||||
"jw": "Javanese",
|
||||
"tk": "Turkmen",
|
||||
"yi": "Yiddish",
|
||||
"so": "Somali",
|
||||
"bn": "Bengali",
|
||||
"nn": "Norwegian Nynorsk",
|
||||
"si": "Sinhala",
|
||||
"yo": "Yoruba",
|
||||
"sa": "Sanskrit",
|
||||
"mi": "Māori",
|
||||
"fo": "Faroese", # codespell:ignore
|
||||
"mt": "Maltese",
|
||||
"tg": "Tajik",
|
||||
"mg": "Malagasy",
|
||||
"haw": "Hawaiian",
|
||||
"km": "Khmer",
|
||||
"br": "Breton",
|
||||
"ps": "Pashto",
|
||||
"ln": "Lingala",
|
||||
"la": "Latin",
|
||||
"ml": "Malayalam",
|
||||
"sq": "Albanian",
|
||||
"su": "Sundanese",
|
||||
"eu": "Basque",
|
||||
"ka": "Georgian",
|
||||
"uz": "Uzbek",
|
||||
"sn": "Shona",
|
||||
"ht": "Haitian",
|
||||
"as": "Assamese",
|
||||
"mn": "Mongolian",
|
||||
"te": "Telugu",
|
||||
"pa": "Panjabi",
|
||||
"tt": "Tatar",
|
||||
"gu": "Gujarati",
|
||||
"oc": "Occitan",
|
||||
"ha": "Hausa",
|
||||
"ba": "Bashkir",
|
||||
"my": "Burmese",
|
||||
"sd": "Sindhi",
|
||||
"am": "Amharic",
|
||||
"lb": "Luxembourgish",
|
||||
"bo": "Tibetan"
|
||||
}
|
||||
|
||||
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||
# TODO configurable
|
||||
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||
OVERLAP_CHUNK_SECOND = 1
|
||||
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
|
||||
|
||||
|
||||
class OpenAISpeechToText(OpenAIServing):
|
||||
"""Base class for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
task_type: Literal["transcribe", "translate"] = "transcribe",
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
processor = cached_get_processor(model_config.model)
|
||||
self.max_audio_clip_s = processor.feature_extractor.chunk_length
|
||||
self.model_sr = processor.feature_extractor.sampling_rate
|
||||
self.hop_length = processor.feature_extractor.hop_length
|
||||
self.task_type = task_type
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
"Overwriting default completion sampling param with: %s",
|
||||
self.default_sampling_params)
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
# Validate request
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
|
||||
if request.language:
|
||||
if request.language in ISO639_1_SUPPORTED_LANGS:
|
||||
pass
|
||||
elif request.language in ISO639_1_OTHER_LANGS:
|
||||
logger.warning(
|
||||
"The selected language %s has limited accuracy with"
|
||||
" reported WER>=0.5. Results may be less accurate "
|
||||
"for this choice.", request.language)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported language: {request.language}."
|
||||
"Language should be one of:" +
|
||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||
raise ValueError("Maximum file size exceeded.")
|
||||
|
||||
with io.BytesIO(audio_data) as bytes_:
|
||||
# NOTE resample to model SR here for efficiency. This is also a
|
||||
# pre-requisite for chunking, as it assumes Whisper SR.
|
||||
y, sr = librosa.load(bytes_, sr=self.model_sr)
|
||||
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
chunks = [y] if duration < 30 else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
for chunk in chunks:
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (chunk, sr),
|
||||
},
|
||||
},
|
||||
"decoder_prompt":
|
||||
(f"<|startoftranscript|>{lang_token}"
|
||||
f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
|
||||
}
|
||||
prompts.append(cast(PromptType, prompt))
|
||||
return prompts, duration
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[T],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> Union[T, AsyncGenerator[str, None], ErrorResponse]:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ['text', 'json']:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`")
|
||||
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if lora_request:
|
||||
return self.create_error_response(
|
||||
"Currently do not support LoRA for "
|
||||
f"{self.task_type.title()}.")
|
||||
if prompt_adapter_request:
|
||||
return self.create_error_response(
|
||||
f"Currently do not support PromptAdapter for "
|
||||
f"{self.task_type.title()}.")
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
|
||||
None]]] = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram.
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
prompts[0]['decoder_prompt'], # type: ignore
|
||||
params=sampling_params,
|
||||
lora_request=None,
|
||||
prompt_adapter_request=None)
|
||||
|
||||
list_result_generator = [
|
||||
self.engine_client.generate(
|
||||
prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
) for prompt in prompts
|
||||
]
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if request.stream:
|
||||
return stream_generator_method(request, list_result_generator,
|
||||
request_id, request_metadata,
|
||||
duration_s)
|
||||
# Non-streaming response.
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
return cast(T, response_class(text=text))
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
async def _speech_to_text_stream_generator(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
|
||||
response_stream_choice_class: Union[
|
||||
type[TranscriptionResponseStreamChoice],
|
||||
type[TranslationResponseStreamChoice]],
|
||||
stream_response_class: Union[type[TranscriptionStreamResponse],
|
||||
type[TranslationStreamResponse]],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
model_name = request.model
|
||||
|
||||
completion_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
|
||||
include_usage = request.stream_include_usage \
|
||||
if request.stream_include_usage else False
|
||||
include_continuous_usage = request.stream_continuous_usage_stats\
|
||||
if include_usage and request.stream_continuous_usage_stats\
|
||||
else False
|
||||
|
||||
try:
|
||||
for result_generator in list_result_generator:
|
||||
async for res in result_generator:
|
||||
# On first result.
|
||||
if res.prompt_token_ids is not None:
|
||||
# Do not count the 4 tokens `<|startoftranscript|>..`
|
||||
# Could be negative when language token
|
||||
# is not specified.
|
||||
num_prompt_tokens = max(
|
||||
len(res.prompt_token_ids) - 4, 0)
|
||||
# NOTE(NickLucche) user can't pass encoder
|
||||
# prompts directly at least not to Whisper.
|
||||
# One indicator of the encoder amount of processing
|
||||
# is the log-mel spectogram length.
|
||||
num_prompt_tokens += ceil(
|
||||
audio_duration_s * self.model_sr / self.hop_length)
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
|
||||
# Just one output (n=1) supported.
|
||||
assert len(res.outputs) == 1
|
||||
output = res.outputs[0]
|
||||
|
||||
delta_message = DeltaMessage(content=output.text)
|
||||
completion_tokens += len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Still generating, send delta update.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message)
|
||||
else:
|
||||
# Model is finished generating.
|
||||
choice_data = response_stream_choice_class(
|
||||
delta=delta_message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
|
||||
chunk = stream_response_class(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage.
|
||||
if include_usage:
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
|
||||
final_usage_chunk = stream_response_class(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.exception("Error in %s stream generator.", self.task_type)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def _split_audio(self, audio_data: np.ndarray,
|
||||
sample_rate: int) -> list[np.ndarray]:
|
||||
chunk_size = sample_rate * self.max_audio_clip_s
|
||||
overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < audio_data.shape[-1]:
|
||||
if i + chunk_size >= audio_data.shape[-1]:
|
||||
# handle last chunk
|
||||
chunks.append(audio_data[..., i:])
|
||||
break
|
||||
|
||||
# Find the best split point in the overlap region
|
||||
search_start = i + chunk_size - overlap_size
|
||||
search_end = min(i + chunk_size, audio_data.shape[-1])
|
||||
split_point = self._find_split_point(audio_data, search_start,
|
||||
search_end)
|
||||
|
||||
# Extract chunk up to the split point
|
||||
chunks.append(audio_data[..., i:split_point])
|
||||
i = split_point
|
||||
return chunks
|
||||
|
||||
def _find_split_point(self, wav: np.ndarray, start_idx: int,
|
||||
end_idx: int) -> int:
|
||||
"""Find the best point to split audio by
|
||||
looking for silence or low amplitude.
|
||||
Args:
|
||||
wav: Audio tensor [1, T]
|
||||
start_idx: Start index of search region
|
||||
end_idx: End index of search region
|
||||
Returns:
|
||||
Index of best splitting point
|
||||
"""
|
||||
segment = wav[start_idx:end_idx]
|
||||
|
||||
# Calculate RMS energy in small windows
|
||||
min_energy = math.inf
|
||||
quietest_idx = 0
|
||||
for i in range(0,
|
||||
len(segment) - MIN_ENERGY_WINDOW_SIZE,
|
||||
MIN_ENERGY_WINDOW_SIZE):
|
||||
window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
|
||||
energy = (window**2).mean()**0.5
|
||||
if energy < min_energy:
|
||||
quietest_idx = i + start_idx
|
||||
min_energy = energy
|
||||
return quietest_idx
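
To make the split heuristic above concrete, a small standalone sketch on synthetic 16 kHz audio (the signal and numbers are illustrative only; the window size mirrors `MIN_ENERGY_WINDOW_SIZE`):

```python
import numpy as np

sr = 16000
window = 1600  # ~100 ms at 16 kHz, as in MIN_ENERGY_WINDOW_SIZE

# One second of tone, half a second of silence, one second of tone.
tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)
wav = np.concatenate([tone, np.zeros(sr // 2), tone])

# Search a region for the quietest RMS window, as _find_split_point does.
start, end = sr // 2, 2 * sr
segment = wav[start:end]
energies = [
    float(np.sqrt(np.mean(segment[i:i + window] ** 2)))
    for i in range(0, len(segment) - window, window)
]
split_point = start + window * int(np.argmin(energies))
print(f"split near sample {split_point}")  # lands inside the silent gap
```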