[Frontend][Docs] Transcription API streaming (#13301)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in: parent 69ff99fdcd, commit fa82b93853
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are supported:
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.

:::{note}
To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
:::

<!-- TODO: api enforced limits + uploading audios -->

Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
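For quick reference, a minimal non-streaming call is sketched below (condensed from the example file above; it assumes a vLLM server at http://localhost:8000/v1 serving a Whisper checkpoint such as `openai/whisper-small`, and `audio.mp3` is a placeholder path):

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
# "audio.mp3" is a placeholder; any supported audio file works.
with open("audio.mp3", "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=f,
        model="openai/whisper-small",
        language="en",
        response_format="json",
        temperature=0.0)
print("transcription result:", transcription.text)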
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json

import httpx
from openai import OpenAI

from vllm.assets.audio import AudioAsset
@@ -13,11 +17,50 @@
    api_key=openai_api_key,
    base_url=openai_api_base,
)
with open(str(mary_had_lamb), "rb") as f:


def sync_openai():
    with open(str(mary_had_lamb), "rb") as f:
        transcription = client.audio.transcriptions.create(
            file=f,
            model="openai/whisper-large-v3",
            model="openai/whisper-small",
            language="en",
            response_format="text",
            response_format="json",
            temperature=0.0)
        print("transcription result:", transcription)
        print("transcription result:", transcription.text)


sync_openai()


# OpenAI Transcription API client does not support streaming.
async def stream_openai_response():
    data = {
        "language": "en",
        'stream': True,
        "model": "openai/whisper-large-v3",
    }
    url = openai_api_base + "/audio/transcriptions"
    print("transcription result:", end=' ')
    async with httpx.AsyncClient() as client:
        with open(str(winning_call), "rb") as f:
            async with client.stream('POST', url, files={'file': f},
                                     data=data) as response:
                async for line in response.aiter_lines():
                    # Each line is a JSON object prefixed with 'data: '
                    if line:
                        if line.startswith('data: '):
                            line = line[len('data: '):]
                        # Last chunk, stream ends
                        if line.strip() == '[DONE]':
                            break
                        # Parse the JSON response
                        chunk = json.loads(line)
                        # Extract and print the content
                        content = chunk['choices'][0].get('delta',
                                                          {}).get('content')
                        print(content, end='')


# Run the asynchronous function
asyncio.run(stream_openai_response())
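Each `data:` line consumed by the loop above carries a JSON-serialized `TranscriptionStreamResponse` (see the protocol additions further down in this commit), and the stream ends with a `data: [DONE]` sentinel. A rough sketch of decoding a single event; the payload values here are illustrative only, not real server output:

import json

# Illustrative SSE event; the field values are made up for this sketch.
line = ('data: {"id": "trsc-1234", "object": "transcription.chunk", '
        '"created": 1700000000, "model": "openai/whisper-large-v3", '
        '"choices": [{"delta": {"content": " hello"}}]}')
payload = json.loads(line[len('data: '):])
print(payload["choices"][0]["delta"]["content"])  # -> " hello"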
@@ -3,12 +3,14 @@
# imports for guided decoding tests
import io
import json
from unittest.mock import patch

import librosa
import numpy as np
import openai
import pytest
import soundfile as sf
from openai._base_client import AsyncAPIClient

from vllm.assets.audio import AudioAsset

@@ -120,3 +122,73 @@ async def test_completion_endpoints():
        res = await client.completions.create(model=model_name, prompt="Hello")
        assert res.code == 400
        assert res.message == "The model does not support Completions API"


@pytest.mark.asyncio
async def test_streaming_response(winning_call):
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]
    transcription = ""
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        res_no_stream = await client.audio.transcriptions.create(
            model=model_name,
            file=winning_call,
            response_format="json",
            language="en",
            temperature=0.0)
        # Unfortunately this only works when the openai client is patched
        # to use streaming mode, not exposed in the transcription api.
        original_post = AsyncAPIClient.post

        async def post_with_stream(*args, **kwargs):
            kwargs['stream'] = True
            return await original_post(*args, **kwargs)

        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
            client = remote_server.get_async_client()
            res = await client.audio.transcriptions.create(
                model=model_name,
                file=winning_call,
                language="en",
                temperature=0.0,
                extra_body=dict(stream=True))
            # Reconstruct from chunks and validate
            async for chunk in res:
                # just a chunk
                text = chunk.choices[0]['delta']['content']
                transcription += text

        assert transcription == res_no_stream.text


@pytest.mark.asyncio
async def test_stream_options(winning_call):
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        original_post = AsyncAPIClient.post

        async def post_with_stream(*args, **kwargs):
            kwargs['stream'] = True
            return await original_post(*args, **kwargs)

        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
            client = remote_server.get_async_client()
            res = await client.audio.transcriptions.create(
                model=model_name,
                file=winning_call,
                language="en",
                temperature=0.0,
                extra_body=dict(stream=True,
                                stream_include_usage=True,
                                stream_continuous_usage_stats=True))
            final = False
            continuous = True
            async for chunk in res:
                if not len(chunk.choices):
                    # final usage sent
                    final = True
                else:
                    continuous = continuous and hasattr(chunk, 'usage')
            assert final and continuous
@@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
    usage: Optional[UsageInfo] = Field(default=None)


class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.
@@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
    timestamps incurs additional latency.
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
@@ -1530,7 +1554,21 @@
            "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return SamplingParams.from_optional(temperature=temperature,
                                            max_tokens=max_tokens)
                                            max_tokens=max_tokens,
                                            output_kind=RequestOutputKind.DELTA
                                            if self.stream \
                                            else RequestOutputKind.FINAL_ONLY)

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data


# Transcription response objects
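Because transcription requests arrive as multipart form data, the streaming switches above are flat fields rather than a nested `stream_options` object. A hypothetical form payload exercising them, mirroring the httpx example earlier in this commit (semantics follow the serving code below):

# Hypothetical form fields for a streaming transcription request.
data = {
    "model": "openai/whisper-large-v3",
    "language": "en",
    "stream": True,
    "stream_include_usage": True,            # ask for a final usage-only chunk
    "stream_continuous_usage_stats": True,   # attach usage to every chunk
}
# Setting either usage option without stream=True is rejected by
# validate_stream_options above with:
#   "Stream options can only be defined when `stream=True`."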
@@ -1,24 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import io
import time
from collections.abc import AsyncGenerator
from typing import Optional, Union, cast
from math import ceil
from typing import Final, Optional, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              RequestResponseMetadata,
                                              TranscriptionRequest,
                                              TranscriptionResponse,
                                              TranscriptionResponseVerbose)
from vllm.entrypoints.openai.protocol import (
    DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
    TranscriptionResponse, TranscriptionResponseStreamChoice,
    TranscriptionStreamResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule

try:
@@ -140,8 +142,6 @@ ISO639_1_OTHER_LANGS = {
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
# TODO get from processor.feature_extractor.chunk_length
MAX_AUDIO_CLIP_DURATION_S = 30


class OpenAIServingTranscription(OpenAIServing):
@@ -163,6 +163,11 @@ class OpenAIServingTranscription(OpenAIServing):

        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        processor = cached_get_processor(model_config.model)
        self.max_audio_clip_s = processor.feature_extractor.chunk_length
        self.model_sr = processor.feature_extractor.sampling_rate
        self.hop_length = processor.feature_extractor.hop_length
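        # For Whisper-family checkpoints these typically resolve to
        # chunk_length=30 (seconds), sampling_rate=16000 (Hz) and
        # hop_length=160 (assumed Whisper defaults; other models may differ).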

        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
@@ -172,7 +177,7 @@ class OpenAIServingTranscription(OpenAIServing):
        self,
        request: TranscriptionRequest,
        audio_data: bytes,
    ) -> PromptType:
    ) -> tuple[PromptType, float]:
        # Validate request
        # TODO language should be optional and can be guessed.
        # For now we default to en. See
@@ -198,9 +203,11 @@ class OpenAIServingTranscription(OpenAIServing):

        with io.BytesIO(audio_data) as bytes_:
            y, sr = librosa.load(bytes_)
            if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:

            duration = librosa.get_duration(y=y, sr=sr)
            if duration > self.max_audio_clip_s:
                raise ValueError(
                    f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
                    f"Maximum clip duration ({self.max_audio_clip_s}s) "
                    "exceeded.")

        prompt = {
@@ -213,13 +220,13 @@ class OpenAIServingTranscription(OpenAIServing):
            "decoder_prompt":
            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
        }
        return cast(PromptType, prompt)
        return cast(PromptType, prompt), duration

    # TODO (varun) : Make verbose response work !
    async def create_transcription(
        self, audio_data: bytes, request: TranscriptionRequest,
        raw_request: Request
    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
               ErrorResponse]:
        """Transcription API similar to OpenAI's API.

@@ -240,8 +247,7 @@ class OpenAIServingTranscription(OpenAIServing):
            return self.create_error_response(
                "Currently only support response_format `text` or `json`")

        # TODO cmpl->transcription?
        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        request_id = f"trsc-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
@@ -261,7 +267,7 @@ class OpenAIServingTranscription(OpenAIServing):
                    "Currently do not support PromptAdapter for Transcription."
                )

            prompt = await self._preprocess_transcription(
            prompt, duration_s = await self._preprocess_transcription(
                request=request,
                audio_data=audio_data,
            )
@@ -293,7 +299,12 @@ class OpenAIServingTranscription(OpenAIServing):
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # TODO(rob): figure out a way to pipe streaming in.
        if request.stream:
            return self.transcription_stream_generator(request,
                                                       result_generator,
                                                       request_id,
                                                       request_metadata,
                                                       duration_s)
        # Non-streaming response.
        try:
            assert result_generator is not None
@@ -305,3 +316,106 @@ class OpenAIServingTranscription(OpenAIServing):
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def transcription_stream_generator(
            self, request: TranscriptionRequest,
            result_generator: AsyncGenerator[RequestOutput, None],
            request_id: str, request_metadata: RequestResponseMetadata,
            audio_duration_s: float) -> AsyncGenerator[str, None]:
        created_time = int(time.time())
        model_name = request.model
        chunk_object_type: Final = "transcription.chunk"

        completion_tokens = 0
        num_prompt_tokens = 0

        include_usage = request.stream_include_usage \
            if request.stream_include_usage else False
        include_continuous_usage = request.stream_continuous_usage_stats\
            if include_usage and request.stream_continuous_usage_stats\
            else False

        try:
            async for res in result_generator:
                # On first result.
                if res.prompt_token_ids is not None:
                    # Do not account the 4-tokens `<|startoftranscript|>..`
                    # Could be negative when language token is not specified.
                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
                    # NOTE(NickLucche) user can't pass encoder prompts directly
                    # at least not to Whisper. One indicator of the encoder
                    # amount of processing is the log-mel spectrogram length.
                    num_prompt_tokens += ceil(audio_duration_s *
                                              self.model_sr / self.hop_length)
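                    # Illustrative (assuming Whisper defaults of sr=16000 and
                    # hop_length=160): a 30 s clip contributes
                    # ceil(30 * 16000 / 160) = 3000 prompt tokens.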

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).

                # Just one output (n=1) supported.
                assert len(res.outputs) == 1
                output = res.outputs[0]

                delta_message = DeltaMessage(content=output.text)
                completion_tokens += len(output.token_ids)

                if output.finish_reason is None:
                    # Still generating, send delta update.
                    choice_data = TranscriptionResponseStreamChoice(
                        delta=delta_message)
                else:
                    # Model is finished generating.
                    choice_data = TranscriptionResponseStreamChoice(
                        delta=delta_message,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason)

                chunk = TranscriptionStreamResponse(id=request_id,
                                                    object=chunk_object_type,
                                                    created=created_time,
                                                    choices=[choice_data],
                                                    model=model_name)

                # handle usage stats if requested & if continuous
                if include_continuous_usage:
                    chunk.usage = UsageInfo(
                        prompt_tokens=num_prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=num_prompt_tokens + completion_tokens,
                    )

                data = chunk.model_dump_json(exclude_unset=True)
                yield f"data: {data}\n\n"

            # Once the final token is handled, if stream_options.include_usage
            # is sent, send the usage.
            if include_usage:
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)

                final_usage_chunk = TranscriptionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=num_prompt_tokens + completion_tokens)

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"