[Frontend][Docs] Transcription API streaming (#13301)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-03-06 11:39:35 +01:00 committed by GitHub
parent 69ff99fdcd
commit fa82b93853
GPG Key ID: B5690EEEBB952194
5 changed files with 297 additions and 26 deletions

View File

@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
:::{note}
To use the Transcriptions API, please install vLLM with the extra audio dependencies using `pip install vllm[audio]`.
:::
<!-- TODO: api enforced limits + uploading audios -->
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
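For a quick start, a minimal non-streaming sketch using the OpenAI client is shown below (the server address, model name, and file path are illustrative; the linked example additionally shows how to stream the response over raw HTTP):

```python
from openai import OpenAI

# Assumes a vLLM server exposing the OpenAI-compatible API at this address.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# "audio.mp3" is a placeholder for any audio clip within the API limits.
with open("audio.mp3", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-small",  # must match the model being served
        file=f,
        language="en",
        response_format="json",
        temperature=0.0,
    )
print("transcription result:", transcription.text)
```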

View File

@ -1,4 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import httpx
from openai import OpenAI
from vllm.assets.audio import AudioAsset
@ -13,11 +17,50 @@ client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
-with open(str(mary_had_lamb), "rb") as f:
+def sync_openai():
+    with open(str(mary_had_lamb), "rb") as f:
transcription = client.audio.transcriptions.create(
file=f,
-model="openai/whisper-large-v3",
+model="openai/whisper-small",
language="en",
-response_format="text",
+response_format="json",
temperature=0.0)
-print("transcription result:", transcription)
+print("transcription result:", transcription.text)
+sync_openai()
# The OpenAI Python client does not currently expose streaming for the
# Transcriptions API, so we stream the raw HTTP response with httpx.
async def stream_openai_response():
data = {
"language": "en",
'stream': True,
"model": "openai/whisper-large-v3",
}
url = openai_api_base + "/audio/transcriptions"
print("transcription result:", end=' ')
async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f:
async with client.stream('POST', url, files={'file': f},
data=data) as response:
async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: '
if line:
if line.startswith('data: '):
line = line[len('data: '):]
# Last chunk, stream ends
if line.strip() == '[DONE]':
break
# Parse the JSON response
chunk = json.loads(line)
# Extract and print the content
content = chunk['choices'][0].get('delta',
{}).get('content')
print(content, end='')
# Run the asynchronous function
asyncio.run(stream_openai_response())

View File

@ -3,12 +3,14 @@
# imports for guided decoding tests
import io
import json
from unittest.mock import patch
import librosa
import numpy as np
import openai
import pytest
import soundfile as sf
from openai._base_client import AsyncAPIClient
from vllm.assets.audio import AudioAsset
@ -120,3 +122,73 @@ async def test_completion_endpoints():
res = await client.completions.create(model=model_name, prompt="Hello")
assert res.code == 400
assert res.message == "The model does not support Completions API"
@pytest.mark.asyncio
async def test_streaming_response(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
transcription = ""
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res_no_stream = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
response_format="json",
language="en",
temperature=0.0)
# Unfortunately this only works when the OpenAI client is patched
# to use streaming mode, which is not exposed by the Transcriptions API.
original_post = AsyncAPIClient.post
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True))
# Reconstruct from chunks and validate
async for chunk in res:
# Each chunk carries a text delta; accumulate them to rebuild the text.
text = chunk.choices[0]['delta']['content']
transcription += text
assert transcription == res_no_stream.text
@pytest.mark.asyncio
async def test_stream_options(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
original_post = AsyncAPIClient.post
async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True,
stream_include_usage=True,
stream_continuous_usage_stats=True))
final = False
continuous = True
async for chunk in res:
if not len(chunk.choices):
# final usage sent
final = True
else:
continuous = continuous and hasattr(chunk, 'usage')
assert final and continuous

View File

@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class TranscriptionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
object: Literal["transcription.chunk"] = "transcription.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranscriptionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
"""
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream options to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
@ -1530,7 +1554,21 @@ class TranscriptionRequest(OpenAIBaseModel):
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return SamplingParams.from_optional(temperature=temperature,
-max_tokens=max_tokens)
+max_tokens=max_tokens,
+output_kind=RequestOutputKind.DELTA
+if self.stream \
+else RequestOutputKind.FINAL_ONLY)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError(
"Stream options can only be defined when `stream=True`.")
return data
# Transcription response objects
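To make the flattened stream options above concrete, here is a sketch of the form fields a client could send alongside the audio file when requesting a streamed transcription with usage stats (values are illustrative; the request itself follows the httpx example in the client script):

```python
# Flattened stream options as plain form fields; the two usage flags are only
# valid together with stream=True, as enforced by validate_stream_options.
data = {
    "language": "en",
    "model": "openai/whisper-small",
    "stream": True,
    "stream_include_usage": True,
    "stream_continuous_usage_stats": True,
}
```

With the OpenAI client, the same fields can be passed via `extra_body`, as done in the tests above.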

View File

@ -1,24 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import io
import time
from collections.abc import AsyncGenerator
-from typing import Optional, Union, cast
+from math import ceil
+from typing import Final, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
-    RequestResponseMetadata,
-    TranscriptionRequest,
-    TranscriptionResponse,
-    TranscriptionResponseVerbose)
+from vllm.entrypoints.openai.protocol import (
+    DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
+    TranscriptionResponse, TranscriptionResponseStreamChoice,
+    TranscriptionStreamResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule
try:
@ -140,8 +142,6 @@ ISO639_1_OTHER_LANGS = {
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
-# TODO get from processor.feature_extractor.chunk_length
-MAX_AUDIO_CLIP_DURATION_S = 30
class OpenAIServingTranscription(OpenAIServing):
@ -163,6 +163,11 @@ class OpenAIServingTranscription(OpenAIServing):
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
processor = cached_get_processor(model_config.model)
self.max_audio_clip_s = processor.feature_extractor.chunk_length
self.model_sr = processor.feature_extractor.sampling_rate
self.hop_length = processor.feature_extractor.hop_length
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
@ -172,7 +177,7 @@ class OpenAIServingTranscription(OpenAIServing):
self,
request: TranscriptionRequest,
audio_data: bytes,
-) -> PromptType:
+) -> tuple[PromptType, float]:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
@ -198,9 +203,11 @@ class OpenAIServingTranscription(OpenAIServing):
with io.BytesIO(audio_data) as bytes_:
y, sr = librosa.load(bytes_)
-if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
+duration = librosa.get_duration(y=y, sr=sr)
+if duration > self.max_audio_clip_s:
raise ValueError(
-f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
+f"Maximum clip duration ({self.max_audio_clip_s}s) "
"exceeded.")
prompt = {
@ -213,13 +220,13 @@ class OpenAIServingTranscription(OpenAIServing):
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
}
-return cast(PromptType, prompt)
+return cast(PromptType, prompt), duration
# TODO (varun) : Make verbose response work !
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest,
raw_request: Request
-) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
+) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
ErrorResponse]:
"""Transcription API similar to OpenAI's API.
@ -240,8 +247,7 @@ class OpenAIServingTranscription(OpenAIServing):
return self.create_error_response(
"Currently only support response_format `text` or `json`")
-# TODO cmpl->transcription?
-request_id = f"cmpl-{self._base_request_id(raw_request)}"
+request_id = f"trsc-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
@ -261,7 +267,7 @@ class OpenAIServingTranscription(OpenAIServing):
"Currently do not support PromptAdapter for Transcription."
)
-prompt = await self._preprocess_transcription(
+prompt, duration_s = await self._preprocess_transcription(
request=request,
audio_data=audio_data,
)
@ -293,7 +299,12 @@ class OpenAIServingTranscription(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
-# TODO(rob): figure out a way to pipe streaming in.
+if request.stream:
+    return self.transcription_stream_generator(request,
+        result_generator,
+        request_id,
+        request_metadata,
+        duration_s)
# Non-streaming response.
try:
assert result_generator is not None
@ -305,3 +316,106 @@ class OpenAIServingTranscription(OpenAIServing):
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async def transcription_stream_generator(
self, request: TranscriptionRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
chunk_object_type: Final = "transcription.chunk"
completion_tokens = 0
num_prompt_tokens = 0
include_usage = request.stream_include_usage \
if request.stream_include_usage else False
include_continuous_usage = request.stream_continuous_usage_stats\
if include_usage and request.stream_continuous_usage_stats\
else False
try:
async for res in result_generator:
# On first result.
if res.prompt_token_ids is not None:
# Do not account for the 4 tokens `<|startoftranscript|>..`
# Could be negative when the language token is not specified.
num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
# NOTE(NickLucche) users can't pass encoder prompts directly,
# at least not to Whisper. One indicator of the amount of encoder
# processing is the log-mel spectrogram length.
num_prompt_tokens += ceil(audio_duration_s *
self.model_sr / self.hop_length)
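# For reference, Whisper's feature extractor defaults to
# sampling_rate=16000 and hop_length=160, i.e. roughly 100 mel
# frames per second of audio.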
# We need to do it here, because if there are exceptions in
# the result_generator, the error needs to be sent as the FIRST
# response (by the try...except below).
# Just one output (n=1) supported.
assert len(res.outputs) == 1
output = res.outputs[0]
delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids)
if output.finish_reason is None:
# Still generating, send delta update.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message)
else:
# Model is finished generating.
choice_data = TranscriptionResponseStreamChoice(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = TranscriptionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Once the final token is handled, send the usage stats
# if stream_include_usage was requested.
if include_usage:
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens)
final_usage_chunk = TranscriptionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message once the response is finished.
yield "data: [DONE]\n\n"