mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:38:38 +08:00
[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint (#12909)
This commit is contained in:
parent
37dfa60037
commit
d84cef76eb
@ -117,7 +117,7 @@ steps:
|
|||||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
|
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
|
||||||
- pytest -v -s entrypoints/test_chat_utils.py
|
- pytest -v -s entrypoints/test_chat_utils.py
|
||||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||||
|
|
||||||
@ -205,7 +205,7 @@ steps:
|
|||||||
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
||||||
# Integration test for streaming correctness (requires special branch).
|
# Integration test for streaming correctness (requires special branch).
|
||||||
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
|
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
|
||||||
- pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine
|
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||||
|
|
||||||
- label: Examples Test # 25min
|
- label: Examples Test # 25min
|
||||||
working_dir: "/vllm-workspace/examples"
|
working_dir: "/vllm-workspace/examples"
|
||||||
@ -339,6 +339,14 @@ steps:
|
|||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- bash ./run-tests.sh -c configs/models-small.txt -t 1
|
- bash ./run-tests.sh -c configs/models-small.txt -t 1
|
||||||
|
|
||||||
|
- label: OpenAI API correctness
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/entrypoints/openai/
|
||||||
|
- vllm/model_executor/models/whisper.py
|
||||||
|
commands: # LMEval+Transcription WER check
|
||||||
|
- pytest -s entrypoints/openai/correctness/
|
||||||
|
|
||||||
- label: Encoder Decoder tests # 5min
|
- label: Encoder Decoder tests # 5min
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
|
|||||||
@ -41,6 +41,8 @@ We currently support the following OpenAI APIs:
|
|||||||
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
|
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
|
||||||
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
|
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
|
||||||
- Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
|
- Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
|
||||||
|
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
|
||||||
|
- Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).
|
||||||
|
|
||||||
In addition, we have the following custom APIs:
|
In addition, we have the following custom APIs:
|
||||||
|
|
||||||
@ -296,6 +298,17 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
|
|||||||
:end-before: end-chat-embedding-extra-params
|
:end-before: end-chat-embedding-extra-params
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
(transcriptions-api)=
|
||||||
|
|
||||||
|
### Transcriptions API
|
||||||
|
|
||||||
|
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
|
||||||
|
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
|
||||||
|
|
||||||
|
<!-- TODO: api enforced limits + uploading audios -->
|
||||||
|
|
||||||
|
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
|
||||||
|
|
||||||
(tokenizer-api)=
|
(tokenizer-api)=
|
||||||
|
|
||||||
### Tokenizer API
|
### Tokenizer API
|
||||||
|
|||||||
23
examples/online_serving/openai_transcription_client.py
Normal file
23
examples/online_serving/openai_transcription_client.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
|
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
|
||||||
|
winning_call = AudioAsset('winning_call').get_local_path()
|
||||||
|
|
||||||
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai_api_key = "EMPTY"
|
||||||
|
openai_api_base = "http://localhost:8000/v1"
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
with open(str(mary_had_lamb), "rb") as f:
|
||||||
|
transcription = client.audio.transcriptions.create(
|
||||||
|
file=f,
|
||||||
|
model="openai/whisper-large-v3",
|
||||||
|
language="en",
|
||||||
|
response_format="text",
|
||||||
|
temperature=0.0)
|
||||||
|
print("transcription result:", transcription)
|
||||||
@ -8,12 +8,11 @@ py-cpuinfo
|
|||||||
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
|
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
|
||||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||||
protobuf # Required by LlamaTokenizer.
|
protobuf # Required by LlamaTokenizer.
|
||||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||||
fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
|
fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
|
||||||
aiohttp
|
aiohttp
|
||||||
openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
||||||
uvicorn[standard]
|
pydantic >= 2.9
|
||||||
pydantic >= 2.9 # Required for fastapi >= 0.113.0
|
|
||||||
prometheus_client >= 0.18.0
|
prometheus_client >= 0.18.0
|
||||||
pillow # Required for image processing
|
pillow # Required for image processing
|
||||||
prometheus-fastapi-instrumentator >= 7.0.0
|
prometheus-fastapi-instrumentator >= 7.0.0
|
||||||
|
|||||||
@ -19,6 +19,7 @@ pqdm
|
|||||||
ray[adag]==2.40.0
|
ray[adag]==2.40.0
|
||||||
sentence-transformers # required for embedding tests
|
sentence-transformers # required for embedding tests
|
||||||
soundfile # required for audio tests
|
soundfile # required for audio tests
|
||||||
|
jiwer # required for audio tests
|
||||||
timm # required for internvl test
|
timm # required for internvl test
|
||||||
torch==2.5.1
|
torch==2.5.1
|
||||||
torchaudio==2.5.1
|
torchaudio==2.5.1
|
||||||
|
|||||||
@ -66,6 +66,7 @@ charset-normalizer==3.4.0
|
|||||||
click==8.1.7
|
click==8.1.7
|
||||||
# via
|
# via
|
||||||
# black
|
# black
|
||||||
|
# jiwer
|
||||||
# nltk
|
# nltk
|
||||||
# ray
|
# ray
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
@ -187,6 +188,8 @@ jinja2==3.1.4
|
|||||||
# via
|
# via
|
||||||
# datamodel-code-generator
|
# datamodel-code-generator
|
||||||
# torch
|
# torch
|
||||||
|
jiwer==3.0.5
|
||||||
|
# via -r requirements-test.in
|
||||||
jmespath==1.0.1
|
jmespath==1.0.1
|
||||||
# via
|
# via
|
||||||
# boto3
|
# boto3
|
||||||
@ -470,6 +473,8 @@ pyyaml==6.0.2
|
|||||||
# timm
|
# timm
|
||||||
# transformers
|
# transformers
|
||||||
# vocos
|
# vocos
|
||||||
|
rapidfuzz==3.12.1
|
||||||
|
# via jiwer
|
||||||
ray[adag]==2.40.0
|
ray[adag]==2.40.0
|
||||||
# via -r requirements-test.in
|
# via -r requirements-test.in
|
||||||
redis==5.2.0
|
redis==5.2.0
|
||||||
|
|||||||
0
tests/entrypoints/openai/correctness/__init__.py
Normal file
0
tests/entrypoints/openai/correctness/__init__.py
Normal file
@ -13,7 +13,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from ....utils import RemoteOpenAIServer
|
||||||
|
|
||||||
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||||
NUM_CONCURRENT = 500
|
NUM_CONCURRENT = 500
|
||||||
@ -0,0 +1,166 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
Evaluate Transcription API correctness by computing Word Error Rate (WER)
|
||||||
|
on a given ASR dataset. When provided, it will also compare the WER against
|
||||||
|
a baseline.
|
||||||
|
This simulates real work usage of the API and makes sure that the frontend and
|
||||||
|
AsyncLLMEngine are working correctly.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import time
|
||||||
|
from statistics import mean, median
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import pytest
|
||||||
|
import soundfile
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from evaluate import load
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from ....utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
|
||||||
|
def to_bytes(y, sr):
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
soundfile.write(buffer, y, sr, format="WAV")
|
||||||
|
buffer.seek(0)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe_audio(client, tokenizer, y, sr):
|
||||||
|
# Send loaded audio directly instead of loading from disk,
|
||||||
|
# dont account for that time though
|
||||||
|
with to_bytes(y, sr) as f:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
transcription = await client.audio.transcriptions.create(
|
||||||
|
file=f,
|
||||||
|
model=tokenizer.name_or_path,
|
||||||
|
language="en",
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
# NOTE there's no streaming in transcriptions, can't measure ttft
|
||||||
|
latency = end_time - start_time
|
||||||
|
num_output_tokens = len(
|
||||||
|
tokenizer(transcription.text, add_special_tokens=False).input_ids)
|
||||||
|
return latency, num_output_tokens, transcription.text
|
||||||
|
|
||||||
|
|
||||||
|
async def bound_transcribe(model_name, sem, client, audio, reference):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
# Use semaphore to limit concurrent requests.
|
||||||
|
async with sem:
|
||||||
|
result = await transcribe_audio(client, tokenizer, *audio)
|
||||||
|
# Normalize *english* output/reference for evaluation.
|
||||||
|
out = tokenizer.normalize(result[2])
|
||||||
|
ref = tokenizer.normalize(reference)
|
||||||
|
return result[:2] + (out, ref)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_dataset(model, client, data, concurrent_request):
|
||||||
|
sem = asyncio.Semaphore(concurrent_request)
|
||||||
|
|
||||||
|
# Warmup call as the first `librosa.load` server-side is quite slow.
|
||||||
|
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
|
||||||
|
_ = await bound_transcribe(model, sem, client, (audio, sr), "")
|
||||||
|
|
||||||
|
tasks: List[asyncio.Task] = []
|
||||||
|
for sample in data:
|
||||||
|
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
|
task = asyncio.create_task(
|
||||||
|
bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
|
||||||
|
tasks.append(task)
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def print_performance_metrics(results, total_time):
|
||||||
|
latencies = [res[0] for res in results]
|
||||||
|
total_tokens = sum([res[1] for res in results])
|
||||||
|
|
||||||
|
total = len(results)
|
||||||
|
print(f"Total Requests: {total}")
|
||||||
|
print(f"Successful Requests: {len(latencies)}")
|
||||||
|
print(f"Average Latency: {mean(latencies):.4f} seconds")
|
||||||
|
print(f"Median Latency: {median(latencies):.4f} seconds")
|
||||||
|
perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
|
||||||
|
print(f"95th Percentile Latency: {perc:.4f} seconds")
|
||||||
|
# Throughput
|
||||||
|
req_throughput = len(latencies) / total_time
|
||||||
|
print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
|
||||||
|
throughput = total_tokens / total_time
|
||||||
|
print(f"Estimated Throughput: {throughput:.2f} tok/s")
|
||||||
|
|
||||||
|
|
||||||
|
def add_duration(sample):
|
||||||
|
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
|
||||||
|
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
|
||||||
|
## Load and filter the dataset
|
||||||
|
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
||||||
|
if 'duration_ms' not in dataset[0]:
|
||||||
|
# compute duration to filter
|
||||||
|
dataset = dataset.map(add_duration)
|
||||||
|
|
||||||
|
# Whisper max supported duration
|
||||||
|
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def run_evaluation(model: str,
|
||||||
|
client,
|
||||||
|
dataset,
|
||||||
|
max_concurrent_reqs: int,
|
||||||
|
n_examples: int = -1,
|
||||||
|
print_metrics: bool = True):
|
||||||
|
if n_examples > 0:
|
||||||
|
dataset = dataset.select(range(n_examples))
|
||||||
|
start = time.perf_counter()
|
||||||
|
results = asyncio.run(
|
||||||
|
process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||||
|
end = time.perf_counter()
|
||||||
|
total_time = end - start
|
||||||
|
print(f"Total Test Time: {total_time:.4f} seconds")
|
||||||
|
if print_metrics:
|
||||||
|
print_performance_metrics(results, total_time)
|
||||||
|
# Compute WER
|
||||||
|
predictions = [res[2] for res in results]
|
||||||
|
references = [res[3] for res in results]
|
||||||
|
wer = load("wer")
|
||||||
|
wer_score = 100 * wer.compute(references=references,
|
||||||
|
predictions=predictions)
|
||||||
|
print("WER:", wer_score)
|
||||||
|
return wer_score
|
||||||
|
|
||||||
|
|
||||||
|
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
|
||||||
|
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
||||||
|
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
|
||||||
|
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||||
|
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||||
|
@pytest.mark.parametrize("expected_wer", [12.744980])
|
||||||
|
def test_wer_correctness(model_name,
|
||||||
|
dataset_repo,
|
||||||
|
expected_wer,
|
||||||
|
n_examples=-1,
|
||||||
|
max_concurrent_request=None):
|
||||||
|
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
|
||||||
|
dataset = load_hf_dataset(dataset_repo)
|
||||||
|
|
||||||
|
if not max_concurrent_request:
|
||||||
|
# No max concurrency
|
||||||
|
max_concurrent_request = n_examples if n_examples > 0\
|
||||||
|
else len(dataset)
|
||||||
|
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
wer = run_evaluation(model_name, client, dataset,
|
||||||
|
max_concurrent_request, n_examples)
|
||||||
|
if expected_wer:
|
||||||
|
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
||||||
122
tests/entrypoints/openai/test_transcription_validation.py
Normal file
122
tests/entrypoints/openai/test_transcription_validation.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# imports for guided decoding tests
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mary_had_lamb():
|
||||||
|
path = AudioAsset('mary_had_lamb').get_local_path()
|
||||||
|
with open(str(path), "rb") as f:
|
||||||
|
yield f
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def winning_call():
|
||||||
|
path = AudioAsset('winning_call').get_local_path()
|
||||||
|
with open(str(path), "rb") as f:
|
||||||
|
yield f
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_audio(mary_had_lamb):
|
||||||
|
model_name = "openai/whisper-large-v3-turbo"
|
||||||
|
server_args = ["--enforce-eager"]
|
||||||
|
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||||
|
prompt = "THE FIRST WORDS I SPOKE"
|
||||||
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
transcription = await client.audio.transcriptions.create(
|
||||||
|
model=model_name,
|
||||||
|
file=mary_had_lamb,
|
||||||
|
language="en",
|
||||||
|
response_format="text",
|
||||||
|
temperature=0.0)
|
||||||
|
out = json.loads(transcription)['text']
|
||||||
|
assert "Mary had a little lamb," in out
|
||||||
|
# This should "force" whisper to continue prompt in all caps
|
||||||
|
transcription_wprompt = await client.audio.transcriptions.create(
|
||||||
|
model=model_name,
|
||||||
|
file=mary_had_lamb,
|
||||||
|
language="en",
|
||||||
|
response_format="text",
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.0)
|
||||||
|
out_capital = json.loads(transcription_wprompt)['text']
|
||||||
|
assert prompt not in out_capital
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bad_requests(mary_had_lamb):
|
||||||
|
model_name = "openai/whisper-small"
|
||||||
|
server_args = ["--enforce-eager"]
|
||||||
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
|
||||||
|
# invalid language
|
||||||
|
with pytest.raises(openai.BadRequestError):
|
||||||
|
await client.audio.transcriptions.create(model=model_name,
|
||||||
|
file=mary_had_lamb,
|
||||||
|
language="hh",
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
# Expect audio too long: repeat the timeseries
|
||||||
|
mary_had_lamb.seek(0)
|
||||||
|
audio, sr = librosa.load(mary_had_lamb)
|
||||||
|
repeated_audio = np.tile(audio, 10)
|
||||||
|
# Repeated audio to buffer
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||||
|
buffer.seek(0)
|
||||||
|
with pytest.raises(openai.BadRequestError):
|
||||||
|
await client.audio.transcriptions.create(model=model_name,
|
||||||
|
file=buffer,
|
||||||
|
language="en",
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_asr_model(winning_call):
|
||||||
|
# text to text model
|
||||||
|
model_name = "JackFram/llama-68m"
|
||||||
|
server_args = ["--enforce-eager"]
|
||||||
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
res = await client.audio.transcriptions.create(model=model_name,
|
||||||
|
file=winning_call,
|
||||||
|
language="en",
|
||||||
|
temperature=0.0)
|
||||||
|
assert res.code == 400 and not res.text
|
||||||
|
assert res.message == "The model does not support Transcriptions API"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_endpoints():
|
||||||
|
# text to text model
|
||||||
|
model_name = "openai/whisper-small"
|
||||||
|
server_args = ["--enforce-eager"]
|
||||||
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
res = await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=[{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
}])
|
||||||
|
assert res.code == 400
|
||||||
|
assert res.message == "The model does not support Chat Completions API"
|
||||||
|
|
||||||
|
res = await client.completions.create(model=model_name, prompt="Hello")
|
||||||
|
assert res.code == 400
|
||||||
|
assert res.message == "The model does not support Completions API"
|
||||||
@ -17,6 +17,7 @@ from vllm.platforms import current_platform
|
|||||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
||||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
|
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
|
||||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
|
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
|
||||||
|
("openai/whisper-small", "transcription", "transcription"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_auto_task(model_id, expected_runner_type, expected_task):
|
def test_auto_task(model_id, expected_runner_type, expected_task):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
@ -28,6 +29,10 @@ class AudioAsset:
|
|||||||
s3_prefix=ASSET_DIR)
|
s3_prefix=ASSET_DIR)
|
||||||
return librosa.load(audio_path, sr=None)
|
return librosa.load(audio_path, sr=None)
|
||||||
|
|
||||||
|
def get_local_path(self) -> Path:
|
||||||
|
return get_vllm_public_assets(filename=f"{self.name}.ogg",
|
||||||
|
s3_prefix=ASSET_DIR)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
|
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
|
||||||
|
|||||||
@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
|||||||
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||||
|
|
||||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||||
"score", "reward"]
|
"score", "reward", "transcription"]
|
||||||
|
|
||||||
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
|
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
|
||||||
"draft"]
|
"draft", "transcription"]
|
||||||
|
|
||||||
RunnerType = Literal["generate", "pooling", "draft"]
|
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
|
||||||
|
|
||||||
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
|
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
|
||||||
"generate": ["generate"],
|
"generate": ["generate"],
|
||||||
"pooling": ["embed", "classify", "score", "reward"],
|
"pooling": ["embed", "classify", "score", "reward"],
|
||||||
"draft": ["draft"],
|
"draft": ["draft"],
|
||||||
|
"transcription": ["transcription"],
|
||||||
}
|
}
|
||||||
|
|
||||||
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
|
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
|
||||||
@ -484,6 +485,8 @@ class ModelConfig:
|
|||||||
return "embed"
|
return "embed"
|
||||||
if ModelRegistry.is_cross_encoder_model(architectures):
|
if ModelRegistry.is_cross_encoder_model(architectures):
|
||||||
return "score"
|
return "score"
|
||||||
|
if ModelRegistry.is_transcription_model(architectures):
|
||||||
|
return "transcription"
|
||||||
|
|
||||||
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
|
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
|
||||||
# Other models follow this pattern
|
# Other models follow this pattern
|
||||||
@ -516,6 +519,8 @@ class ModelConfig:
|
|||||||
runner_support: Dict[RunnerType, bool] = {
|
runner_support: Dict[RunnerType, bool] = {
|
||||||
# NOTE: Listed from highest to lowest priority,
|
# NOTE: Listed from highest to lowest priority,
|
||||||
# in case the model supports multiple of them
|
# in case the model supports multiple of them
|
||||||
|
"transcription":
|
||||||
|
ModelRegistry.is_transcription_model(architectures),
|
||||||
"generate": ModelRegistry.is_text_generation_model(architectures),
|
"generate": ModelRegistry.is_text_generation_model(architectures),
|
||||||
"pooling": ModelRegistry.is_pooling_model(architectures),
|
"pooling": ModelRegistry.is_pooling_model(architectures),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,10 +16,10 @@ from argparse import Namespace
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
|
from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
@ -61,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
ScoreRequest, ScoreResponse,
|
ScoreRequest, ScoreResponse,
|
||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
|
TranscriptionRequest,
|
||||||
|
TranscriptionResponse,
|
||||||
UnloadLoraAdapterRequest)
|
UnloadLoraAdapterRequest)
|
||||||
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
|
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
@ -75,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
|
|||||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||||
from vllm.entrypoints.openai.serving_tokenization import (
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
OpenAIServingTokenization)
|
OpenAIServingTokenization)
|
||||||
|
from vllm.entrypoints.openai.serving_transcription import (
|
||||||
|
OpenAIServingTranscription)
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
from vllm.entrypoints.utils import with_cancellation
|
from vllm.entrypoints.utils import with_cancellation
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -327,6 +331,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
|
|||||||
return request.app.state.openai_serving_tokenization
|
return request.app.state.openai_serving_tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||||
|
return request.app.state.openai_serving_transcription
|
||||||
|
|
||||||
|
|
||||||
def engine_client(request: Request) -> EngineClient:
|
def engine_client(request: Request) -> EngineClient:
|
||||||
return request.app.state.engine_client
|
return request.app.state.engine_client
|
||||||
|
|
||||||
@ -520,6 +528,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
|||||||
return await create_score(request, raw_request)
|
return await create_score(request, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/audio/transcriptions")
|
||||||
|
@with_cancellation
|
||||||
|
async def create_transcriptions(request: Annotated[TranscriptionRequest,
|
||||||
|
Form()],
|
||||||
|
raw_request: Request):
|
||||||
|
|
||||||
|
handler = transcription(raw_request)
|
||||||
|
if handler is None:
|
||||||
|
return base(raw_request).create_error_response(
|
||||||
|
message="The model does not support Transcriptions API")
|
||||||
|
|
||||||
|
audio_data = await request.file.read()
|
||||||
|
generator = await handler.create_transcription(audio_data, request,
|
||||||
|
raw_request)
|
||||||
|
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
|
||||||
|
elif isinstance(generator, TranscriptionResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
|
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
|
||||||
@with_cancellation
|
@with_cancellation
|
||||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||||
@ -832,6 +865,12 @@ async def init_app_state(
|
|||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
)
|
)
|
||||||
|
state.openai_serving_transcription = OpenAIServingTranscription(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
state.openai_serving_models,
|
||||||
|
request_logger=request_logger,
|
||||||
|
) if model_config.runner_type == "transcription" else None
|
||||||
state.task = model_config.task
|
state.task = model_config.task
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,9 +8,10 @@ from argparse import Namespace
|
|||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from fastapi import UploadFile
|
||||||
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
|
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
|
||||||
ValidationInfo, field_validator, model_validator)
|
ValidationInfo, field_validator, model_validator)
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated, TypeAlias
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -1426,3 +1427,163 @@ class LoadLoraAdapterRequest(BaseModel):
|
|||||||
class UnloadLoraAdapterRequest(BaseModel):
|
class UnloadLoraAdapterRequest(BaseModel):
|
||||||
lora_name: str
|
lora_name: str
|
||||||
lora_int_id: Optional[int] = Field(default=None)
|
lora_int_id: Optional[int] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
## Protocols for Audio
|
||||||
|
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json",
|
||||||
|
"vtt"]
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionRequest(OpenAIBaseModel):
|
||||||
|
# Ordered by official OpenAI API documentation
|
||||||
|
#https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||||
|
|
||||||
|
file: UploadFile
|
||||||
|
"""
|
||||||
|
The audio file object (not file name) to transcribe, in one of these
|
||||||
|
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
"""ID of the model to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
language: Optional[str] = None
|
||||||
|
"""The language of the input audio.
|
||||||
|
|
||||||
|
Supplying the input language in
|
||||||
|
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||||
|
will improve accuracy and latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt: str = Field(default="")
|
||||||
|
"""An optional text to guide the model's style or continue a previous audio
|
||||||
|
segment.
|
||||||
|
|
||||||
|
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||||
|
should match the audio language.
|
||||||
|
"""
|
||||||
|
|
||||||
|
response_format: AudioResponseFormat = Field(default="json")
|
||||||
|
"""
|
||||||
|
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||||
|
`verbose_json`, or `vtt`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
## TODO (varun) : Support if set to 0, certain thresholds are met !!
|
||||||
|
temperature: float = Field(default=0.0)
|
||||||
|
"""The sampling temperature, between 0 and 1.
|
||||||
|
|
||||||
|
Higher values like 0.8 will make the output more random, while lower values
|
||||||
|
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||||
|
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||||
|
to automatically increase the temperature until certain thresholds are hit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
timestamp_granularities: List[Literal["word", "segment"]] = Field(
|
||||||
|
alias="timestamp_granularities[]", default=[])
|
||||||
|
"""The timestamp granularities to populate for this transcription.
|
||||||
|
|
||||||
|
`response_format` must be set `verbose_json` to use timestamp granularities.
|
||||||
|
Either or both of these options are supported: `word`, or `segment`. Note:
|
||||||
|
There is no additional latency for segment timestamps, but generating word
|
||||||
|
timestamps incurs additional latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default sampling parameters for transcription requests.
|
||||||
|
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||||
|
"temperature": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_sampling_params(
|
||||||
|
self,
|
||||||
|
default_max_tokens: int,
|
||||||
|
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||||
|
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||||
|
max_tokens = default_max_tokens
|
||||||
|
|
||||||
|
if default_sampling_params is None:
|
||||||
|
default_sampling_params = {}
|
||||||
|
# Default parameters
|
||||||
|
if (temperature := self.temperature) is None:
|
||||||
|
temperature = default_sampling_params.get(
|
||||||
|
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||||
|
|
||||||
|
return SamplingParams.from_optional(temperature=temperature,
|
||||||
|
max_tokens=max_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
# Transcription response objects
|
||||||
|
class TranscriptionResponse(OpenAIBaseModel):
|
||||||
|
text: str
|
||||||
|
"""The transcribed text."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionWord(OpenAIBaseModel):
|
||||||
|
end: float
|
||||||
|
"""End time of the word in seconds."""
|
||||||
|
|
||||||
|
start: float
|
||||||
|
"""Start time of the word in seconds."""
|
||||||
|
|
||||||
|
word: str
|
||||||
|
"""The text content of the word."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionSegment(OpenAIBaseModel):
|
||||||
|
id: int
|
||||||
|
"""Unique identifier of the segment."""
|
||||||
|
|
||||||
|
avg_logprob: float
|
||||||
|
"""Average logprob of the segment.
|
||||||
|
|
||||||
|
If the value is lower than -1, consider the logprobs failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
compression_ratio: float
|
||||||
|
"""Compression ratio of the segment.
|
||||||
|
|
||||||
|
If the value is greater than 2.4, consider the compression failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
end: float
|
||||||
|
"""End time of the segment in seconds."""
|
||||||
|
|
||||||
|
no_speech_prob: float
|
||||||
|
"""Probability of no speech in the segment.
|
||||||
|
|
||||||
|
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||||
|
this segment silent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
seek: int
|
||||||
|
"""Seek offset of the segment."""
|
||||||
|
|
||||||
|
start: float
|
||||||
|
"""Start time of the segment in seconds."""
|
||||||
|
|
||||||
|
temperature: float
|
||||||
|
"""Temperature parameter used for generating the segment."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
"""Text content of the segment."""
|
||||||
|
|
||||||
|
tokens: List[int]
|
||||||
|
"""Array of token IDs for the text content."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||||
|
duration: str
|
||||||
|
"""The duration of the input audio."""
|
||||||
|
|
||||||
|
language: str
|
||||||
|
"""The language of the input audio."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
"""The transcribed text."""
|
||||||
|
|
||||||
|
segments: Optional[List[TranscriptionSegment]] = None
|
||||||
|
"""Segments of the transcribed text and their corresponding details."""
|
||||||
|
|
||||||
|
words: Optional[List[TranscriptionWord]] = None
|
||||||
|
"""Extracted words and their corresponding timestamps."""
|
||||||
|
|||||||
@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
ErrorResponse, RerankRequest,
|
ErrorResponse, RerankRequest,
|
||||||
ScoreRequest,
|
ScoreRequest,
|
||||||
TokenizeChatRequest,
|
TokenizeChatRequest,
|
||||||
TokenizeCompletionRequest)
|
TokenizeCompletionRequest,
|
||||||
|
TranscriptionRequest)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
|||||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||||
TokenizeChatRequest]
|
TokenizeChatRequest]
|
||||||
|
|
||||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
|
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
|
||||||
|
TranscriptionRequest]
|
||||||
|
|
||||||
|
|
||||||
class TextTokensPrompt(TypedDict):
|
class TextTokensPrompt(TypedDict):
|
||||||
|
|||||||
305
vllm/entrypoints/openai/serving_transcription.py
Normal file
305
vllm/entrypoints/openai/serving_transcription.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import AsyncGenerator, Optional, Union, cast
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||||
|
RequestResponseMetadata,
|
||||||
|
TranscriptionRequest,
|
||||||
|
TranscriptionResponse,
|
||||||
|
TranscriptionResponseVerbose)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.inputs.data import PromptType
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
|
||||||
|
# TODO these configs should live somewhere with the model so we can support
|
||||||
|
# additional ones
|
||||||
|
|
||||||
|
ISO639_1_SUPPORTED_LANGS = {
|
||||||
|
"af": "Afrikaans",
|
||||||
|
"ar": "Arabic",
|
||||||
|
"hy": "Armenian",
|
||||||
|
"az": "Azerbaijani",
|
||||||
|
"be": "Belarusian",
|
||||||
|
"bs": "Bosnian",
|
||||||
|
"bg": "Bulgarian",
|
||||||
|
"ca": "Catalan",
|
||||||
|
"zh": "Chinese",
|
||||||
|
"hr": "Croatian",
|
||||||
|
"cs": "Czech",
|
||||||
|
"da": "Danish",
|
||||||
|
"nl": "Dutch",
|
||||||
|
"en": "English",
|
||||||
|
"et": "Estonian",
|
||||||
|
"fi": "Finnish",
|
||||||
|
"fr": "French",
|
||||||
|
"gl": "Galician",
|
||||||
|
"de": "German",
|
||||||
|
"el": "Greek",
|
||||||
|
"he": "Hebrew",
|
||||||
|
"hi": "Hindi",
|
||||||
|
"hu": "Hungarian",
|
||||||
|
"is": "Icelandic",
|
||||||
|
"id": "Indonesian",
|
||||||
|
"it": "Italian",
|
||||||
|
"ja": "Japanese",
|
||||||
|
"kn": "Kannada",
|
||||||
|
"kk": "Kazakh",
|
||||||
|
"ko": "Korean",
|
||||||
|
"lv": "Latvian",
|
||||||
|
"lt": "Lithuanian",
|
||||||
|
"mk": "Macedonian",
|
||||||
|
"ms": "Malay",
|
||||||
|
"mr": "Marathi",
|
||||||
|
"mi": "Maori",
|
||||||
|
"ne": "Nepali",
|
||||||
|
"no": "Norwegian",
|
||||||
|
"fa": "Persian",
|
||||||
|
"pl": "Polish",
|
||||||
|
"pt": "Portuguese",
|
||||||
|
"ro": "Romanian",
|
||||||
|
"ru": "Russian",
|
||||||
|
"sr": "Serbian",
|
||||||
|
"sk": "Slovak",
|
||||||
|
"sl": "Slovenian",
|
||||||
|
"es": "Spanish",
|
||||||
|
"sw": "Swahili",
|
||||||
|
"sv": "Swedish",
|
||||||
|
"tl": "Tagalog",
|
||||||
|
"ta": "Tamil",
|
||||||
|
"th": "Thai",
|
||||||
|
"tr": "Turkish",
|
||||||
|
"uk": "Ukrainian",
|
||||||
|
"ur": "Urdu",
|
||||||
|
"vi": "Vietnamese",
|
||||||
|
"cy": "Welsh"
|
||||||
|
}
|
||||||
|
ISO639_1_OTHER_LANGS = {
|
||||||
|
"lo": "Lao",
|
||||||
|
"jw": "Javanese",
|
||||||
|
"tk": "Turkmen",
|
||||||
|
"yi": "Yiddish",
|
||||||
|
"so": "Somali",
|
||||||
|
"bn": "Bengali",
|
||||||
|
"nn": "Norwegian Nynorsk",
|
||||||
|
"si": "Sinhala",
|
||||||
|
"yo": "Yoruba",
|
||||||
|
"sa": "Sanskrit",
|
||||||
|
"mi": "Māori",
|
||||||
|
"fo": "Faroese", # codespell:ignore
|
||||||
|
"mt": "Maltese",
|
||||||
|
"tg": "Tajik",
|
||||||
|
"mg": "Malagasy",
|
||||||
|
"haw": "Hawaiian",
|
||||||
|
"km": "Khmer",
|
||||||
|
"br": "Breton",
|
||||||
|
"ps": "Pashto",
|
||||||
|
"ln": "Lingala",
|
||||||
|
"la": "Latin",
|
||||||
|
"ml": "Malayalam",
|
||||||
|
"sq": "Albanian",
|
||||||
|
"su": "Sundanese",
|
||||||
|
"eu": "Basque",
|
||||||
|
"ka": "Georgian",
|
||||||
|
"uz": "Uzbek",
|
||||||
|
"sn": "Shona",
|
||||||
|
"ht": "Haitian",
|
||||||
|
"as": "Assamese",
|
||||||
|
"mn": "Mongolian",
|
||||||
|
"te": "Telugu",
|
||||||
|
"pa": "Panjabi",
|
||||||
|
"tt": "Tatar",
|
||||||
|
"gu": "Gujarati",
|
||||||
|
"oc": "Occitan",
|
||||||
|
"ha": "Hausa",
|
||||||
|
"ba": "Bashkir",
|
||||||
|
"my": "Burmese",
|
||||||
|
"sd": "Sindhi",
|
||||||
|
"am": "Amharic",
|
||||||
|
"lb": "Luxembourgish",
|
||||||
|
"bo": "Tibetan"
|
||||||
|
}
|
||||||
|
|
||||||
|
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
|
||||||
|
# TODO configurable
|
||||||
|
MAX_AUDIO_CLIP_FILESIZE_MB = 25
|
||||||
|
# TODO get from processor.feature_extractor.chunk_length
|
||||||
|
MAX_AUDIO_CLIP_DURATION_S = 30
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingTranscription(OpenAIServing):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine_client: EngineClient,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
models: OpenAIServingModels,
|
||||||
|
*,
|
||||||
|
request_logger: Optional[RequestLogger],
|
||||||
|
return_tokens_as_token_ids: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(engine_client=engine_client,
|
||||||
|
model_config=model_config,
|
||||||
|
models=models,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||||
|
|
||||||
|
diff_sampling_param = self.model_config.get_diff_sampling_param()
|
||||||
|
if diff_sampling_param:
|
||||||
|
logger.info(
|
||||||
|
"Overwriting default completion sampling param with: %s",
|
||||||
|
diff_sampling_param)
|
||||||
|
|
||||||
|
async def _preprocess_transcription(
|
||||||
|
self,
|
||||||
|
request: TranscriptionRequest,
|
||||||
|
audio_data: bytes,
|
||||||
|
) -> PromptType:
|
||||||
|
# Validate request
|
||||||
|
# TODO language should be optional and can be guessed.
|
||||||
|
# For now we default to en. See
|
||||||
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||||
|
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
|
||||||
|
if request.language:
|
||||||
|
if request.language in ISO639_1_SUPPORTED_LANGS:
|
||||||
|
pass
|
||||||
|
elif request.language in ISO639_1_OTHER_LANGS:
|
||||||
|
logger.warning(
|
||||||
|
"The selected language %s has limited accuracy with"
|
||||||
|
" reported WER>=0.5. Results may be less accurate "
|
||||||
|
"for this choice.", request.language)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported language: {request.language}."
|
||||||
|
"Language should be one of:" +
|
||||||
|
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
||||||
|
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||||
|
|
||||||
|
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
|
||||||
|
raise ValueError("Maximum file size exceeded.")
|
||||||
|
|
||||||
|
with io.BytesIO(audio_data) as bytes_:
|
||||||
|
y, sr = librosa.load(bytes_)
|
||||||
|
if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
|
||||||
|
raise ValueError(
|
||||||
|
f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
|
||||||
|
"exceeded.")
|
||||||
|
|
||||||
|
prompt = {
|
||||||
|
"encoder_prompt": {
|
||||||
|
"prompt": "",
|
||||||
|
"multi_modal_data": {
|
||||||
|
"audio": (y, sr),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"decoder_prompt":
|
||||||
|
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
|
||||||
|
}
|
||||||
|
return cast(PromptType, prompt)
|
||||||
|
|
||||||
|
# TODO (varun) : Make verbose response work !
|
||||||
|
async def create_transcription(
|
||||||
|
self, audio_data: bytes, request: TranscriptionRequest,
|
||||||
|
raw_request: Request
|
||||||
|
) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
|
||||||
|
ErrorResponse]:
|
||||||
|
"""Transcription API similar to OpenAI's API.
|
||||||
|
|
||||||
|
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||||
|
for the API specification. This API mimics the OpenAI transcription API.
|
||||||
|
"""
|
||||||
|
error_check_ret = await self._check_model(request)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||||
|
# This is required for the streaming case, where we return a
|
||||||
|
# success status before we actually start generating text :).
|
||||||
|
if self.engine_client.errored:
|
||||||
|
raise self.engine_client.dead_error
|
||||||
|
|
||||||
|
if request.response_format not in ['text', 'json']:
|
||||||
|
return self.create_error_response(
|
||||||
|
"Currently only support response_format `text` or `json`")
|
||||||
|
|
||||||
|
# TODO cmpl->transcription?
|
||||||
|
request_id = f"cmpl-{self._base_request_id(raw_request)}"
|
||||||
|
|
||||||
|
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||||
|
if raw_request:
|
||||||
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
lora_request,
|
||||||
|
prompt_adapter_request,
|
||||||
|
) = self._maybe_get_adapters(request)
|
||||||
|
|
||||||
|
if lora_request:
|
||||||
|
return self.create_error_response(
|
||||||
|
"Currently do not support LoRA for Transcription.")
|
||||||
|
if prompt_adapter_request:
|
||||||
|
return self.create_error_response(
|
||||||
|
"Currently do not support PromptAdapter for Transcription."
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = await self._preprocess_transcription(
|
||||||
|
request=request,
|
||||||
|
audio_data=audio_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
|
||||||
|
try:
|
||||||
|
# TODO(rob): subtract len of tokenized prompt.
|
||||||
|
default_max_tokens = self.model_config.max_model_len
|
||||||
|
default_params = self.model_config.get_diff_sampling_param()
|
||||||
|
sampling_params = request.to_sampling_params(
|
||||||
|
default_max_tokens, default_params)
|
||||||
|
|
||||||
|
self._log_inputs(
|
||||||
|
request_id,
|
||||||
|
prompt['decoder_prompt'], # type: ignore
|
||||||
|
params=sampling_params,
|
||||||
|
lora_request=None,
|
||||||
|
prompt_adapter_request=None)
|
||||||
|
|
||||||
|
result_generator = self.engine_client.generate(
|
||||||
|
prompt,
|
||||||
|
sampling_params,
|
||||||
|
request_id,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
# TODO(rob): figure out a way to pipe streaming in.
|
||||||
|
# Non-streaming response.
|
||||||
|
try:
|
||||||
|
async for op in result_generator:
|
||||||
|
result = op
|
||||||
|
return TranscriptionResponse(text=result.outputs[0].text)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return self.create_error_response("Client disconnected")
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
@ -441,3 +441,30 @@ def supports_cross_encoding(
|
|||||||
model: Union[Type[object], object],
|
model: Union[Type[object], object],
|
||||||
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
||||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SupportsTranscription(Protocol):
|
||||||
|
"""The interface required for all models that support transcription."""
|
||||||
|
|
||||||
|
supports_transcription: ClassVar[Literal[True]] = True
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def supports_transcription(
|
||||||
|
model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def supports_transcription(
|
||||||
|
model: Union[Type[object], object],
|
||||||
|
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
|
||||||
|
if isinstance(model, type):
|
||||||
|
return isinstance(model, SupportsTranscription)
|
||||||
|
|
||||||
|
return isinstance(model, SupportsTranscription)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from vllm.logger import init_logger
|
|||||||
|
|
||||||
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
|
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
|
||||||
supports_cross_encoding, supports_multimodal,
|
supports_cross_encoding, supports_multimodal,
|
||||||
supports_pp)
|
supports_pp, supports_transcription)
|
||||||
from .interfaces_base import is_text_generation_model
|
from .interfaces_base import is_text_generation_model
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -224,6 +224,7 @@ class _ModelInfo:
|
|||||||
has_inner_state: bool
|
has_inner_state: bool
|
||||||
is_attention_free: bool
|
is_attention_free: bool
|
||||||
is_hybrid: bool
|
is_hybrid: bool
|
||||||
|
supports_transcription: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||||
@ -237,7 +238,7 @@ class _ModelInfo:
|
|||||||
has_inner_state=has_inner_state(model),
|
has_inner_state=has_inner_state(model),
|
||||||
is_attention_free=is_attention_free(model),
|
is_attention_free=is_attention_free(model),
|
||||||
is_hybrid=is_hybrid(model),
|
is_hybrid=is_hybrid(model),
|
||||||
)
|
supports_transcription=supports_transcription(model))
|
||||||
|
|
||||||
|
|
||||||
class _BaseRegisteredModel(ABC):
|
class _BaseRegisteredModel(ABC):
|
||||||
@ -485,6 +486,13 @@ class _ModelRegistry:
|
|||||||
model_cls, _ = self.inspect_model_cls(architectures)
|
model_cls, _ = self.inspect_model_cls(architectures)
|
||||||
return model_cls.is_hybrid
|
return model_cls.is_hybrid
|
||||||
|
|
||||||
|
def is_transcription_model(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
|
) -> bool:
|
||||||
|
model_cls, _ = self.inspect_model_cls(architectures)
|
||||||
|
return model_cls.supports_transcription
|
||||||
|
|
||||||
|
|
||||||
ModelRegistry = _ModelRegistry({
|
ModelRegistry = _ModelRegistry({
|
||||||
model_arch:
|
model_arch:
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from vllm.multimodal.audio import resample_audio
|
|||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal
|
from .interfaces import SupportsMultiModal, SupportsTranscription
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
|
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -637,7 +637,8 @@ def input_mapper_for_whisper(
|
|||||||
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
|
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
|
||||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||||
"audio", get_max_whisper_audio_tokens)
|
"audio", get_max_whisper_audio_tokens)
|
||||||
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
|
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||||
|
SupportsMultiModal):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"self_attn.qkv_proj": [
|
"self_attn.qkv_proj": [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user