[Frontend][Docs] Transcription API streaming (#13301)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent 69ff99fdcd
commit fa82b93853
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
 you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
 
+:::{note}
+To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
+:::
+
 <!-- TODO: api enforced limits + uploading audios -->
 
 Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
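For orientation, the snippet below is a minimal sketch of the basic (non-streaming) call the docs refer to. It assumes a vLLM server is already serving a Whisper checkpoint such as `openai/whisper-small` (the model used in the example and tests in this commit) at the default local address; the API key, base URL and audio path are illustrative. The full example shipped with this commit follows in the next file.

# Minimal sketch, assuming a server started with e.g. `vllm serve openai/whisper-small`
# is reachable at the default address; "audio.mp3" is a placeholder path.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
with open("audio.mp3", "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=f, model="openai/whisper-small", language="en",
        response_format="json", temperature=0.0)
print(transcription.text)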
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import asyncio
+import json
+
+import httpx
 from openai import OpenAI
 
 from vllm.assets.audio import AudioAsset
@@ -13,11 +17,50 @@ client = OpenAI(
     api_key=openai_api_key,
     base_url=openai_api_base,
 )
-with open(str(mary_had_lamb), "rb") as f:
-    transcription = client.audio.transcriptions.create(
-        file=f,
-        model="openai/whisper-large-v3",
-        language="en",
-        response_format="text",
-        temperature=0.0)
-    print("transcription result:", transcription)
+
+
+def sync_openai():
+    with open(str(mary_had_lamb), "rb") as f:
+        transcription = client.audio.transcriptions.create(
+            file=f,
+            model="openai/whisper-small",
+            language="en",
+            response_format="json",
+            temperature=0.0)
+        print("transcription result:", transcription.text)
+
+
+sync_openai()
+
+
+# OpenAI Transcription API client does not support streaming.
+async def stream_openai_response():
+    data = {
+        "language": "en",
+        'stream': True,
+        "model": "openai/whisper-large-v3",
+    }
+    url = openai_api_base + "/audio/transcriptions"
+    print("transcription result:", end=' ')
+    async with httpx.AsyncClient() as client:
+        with open(str(winning_call), "rb") as f:
+            async with client.stream('POST', url, files={'file': f},
+                                     data=data) as response:
+                async for line in response.aiter_lines():
+                    # Each line is a JSON object prefixed with 'data: '
+                    if line:
+                        if line.startswith('data: '):
+                            line = line[len('data: '):]
+                        # Last chunk, stream ends
+                        if line.strip() == '[DONE]':
+                            break
+                        # Parse the JSON response
+                        chunk = json.loads(line)
+                        # Extract and print the content
+                        content = chunk['choices'][0].get('delta',
+                                                          {}).get('content')
+                        print(content, end='')
+
+
+# Run the asynchronous function
+asyncio.run(stream_openai_response())
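Each line read by `aiter_lines()` above is a server-sent event; the stream is closed by a literal `data: [DONE]` line. A hedged sketch of the parsing with an illustrative payload (field values are made up, the shape follows the `TranscriptionStreamResponse` model added later in this diff):

import json

# Illustrative SSE line as emitted by the server; values are placeholders.
line = ('data: {"id": "trsc-abc123", "object": "transcription.chunk", '
        '"created": 1700000000, "model": "openai/whisper-large-v3", '
        '"choices": [{"delta": {"content": "and the ball"}}]}')

payload = json.loads(line[len('data: '):])
print(payload["choices"][0]["delta"]["content"])  # -> "and the ball"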
@@ -3,12 +3,14 @@
 # imports for guided decoding tests
 import io
 import json
+from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
 import soundfile as sf
+from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
@@ -120,3 +122,73 @@ async def test_completion_endpoints():
     res = await client.completions.create(model=model_name, prompt="Hello")
     assert res.code == 400
     assert res.message == "The model does not support Completions API"
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    transcription = ""
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res_no_stream = await client.audio.transcriptions.create(
+            model=model_name,
+            file=winning_call,
+            response_format="json",
+            language="en",
+            temperature=0.0)
+        # Unfortunately this only works when the openai client is patched
+        # to use streaming mode, not exposed in the transcription api.
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True))
+            # Reconstruct from chunks and validate
+            async for chunk in res:
+                # just a chunk
+                text = chunk.choices[0]['delta']['content']
+                transcription += text
+
+        assert transcription == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True,
+                                stream_include_usage=True,
+                                stream_continuous_usage_stats=True))
+            final = False
+            continuous = True
+            async for chunk in res:
+                if not len(chunk.choices):
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and hasattr(chunk, 'usage')
+            assert final and continuous
@@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
+class TranscriptionResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranscriptionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
+    object: Literal["transcription.chunk"] = "transcription.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[TranscriptionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
 class BatchRequestInput(OpenAIBaseModel):
     """
     The per-line object of the batch input file.
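A short sketch of how the two models above are meant to be used, mirroring the serving code later in this diff: one chunk per engine step, serialized with `model_dump_json` onto the SSE stream. It assumes vLLM with this change is importable; the model name and content are illustrative.

from vllm.entrypoints.openai.protocol import (
    DeltaMessage, TranscriptionResponseStreamChoice, TranscriptionStreamResponse)

chunk = TranscriptionStreamResponse(
    model="openai/whisper-small",
    choices=[
        TranscriptionResponseStreamChoice(
            delta=DeltaMessage(content="Mary had a little lamb"))
    ])
# One SSE event, roughly:
#   data: {"object": "transcription.chunk", "model": "openai/whisper-small",
#          "choices": [{"delta": {"content": "Mary had a little lamb"}, ...}], ...}
print(f"data: {chunk.model_dump_json(exclude_unset=True)}\n")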
@@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion as the Chat
+    Completion endpoint.
+    """
+    # Flattened stream option to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
         "temperature": 0,
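Because the transcription endpoint takes multipart form data rather than a JSON body, the stream options are flattened into top-level fields. A hedged sketch of the corresponding form payload, mirroring the httpx example and the tests earlier in this diff (model name taken from the tests):

# Form fields sent alongside the audio file; parsed into TranscriptionRequest above.
data = {
    "model": "openai/whisper-small",
    "language": "en",
    "stream": True,
    "stream_include_usage": True,
    "stream_continuous_usage_stats": True,
}
# Equivalent when going through the patched OpenAI client, as in the tests:
#   extra_body=dict(stream=True, stream_include_usage=True,
#                   stream_continuous_usage_stats=True)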
@@ -1530,7 +1554,21 @@ class TranscriptionRequest(OpenAIBaseModel):
             "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return SamplingParams.from_optional(temperature=temperature,
-                                            max_tokens=max_tokens)
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
 
 
 # Transcription response objects
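The `output_kind` switch above is what turns streaming on at the engine level: with `RequestOutputKind.DELTA` each `RequestOutput` carries only the text generated since the previous step, while `FINAL_ONLY` yields a single complete result. A minimal sketch, assuming vLLM is importable (the `max_tokens` value is illustrative):

from vllm.sampling_params import RequestOutputKind, SamplingParams

stream = True
params = SamplingParams.from_optional(
    temperature=0.0,
    max_tokens=100,
    # Streaming request -> per-step deltas; non-streaming -> one final output.
    output_kind=RequestOutputKind.DELTA if stream
    else RequestOutputKind.FINAL_ONLY)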
@@ -1,24 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import io
+import time
 from collections.abc import AsyncGenerator
-from typing import Optional, Union, cast
+from math import ceil
+from typing import Final, Optional, Union, cast
 
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
-                                              RequestResponseMetadata,
-                                              TranscriptionRequest,
-                                              TranscriptionResponse,
-                                              TranscriptionResponseVerbose)
+from vllm.entrypoints.openai.protocol import (
+    DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
+    TranscriptionResponse, TranscriptionResponseStreamChoice,
+    TranscriptionStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -140,8 +142,6 @@ ISO639_1_OTHER_LANGS = {
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-# TODO get from processor.feature_extractor.chunk_length
-MAX_AUDIO_CLIP_DURATION_S = 30
 
 
 class OpenAIServingTranscription(OpenAIServing):
@@ -163,6 +163,11 @@ class OpenAIServingTranscription(OpenAIServing):
 
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
+        processor = cached_get_processor(model_config.model)
+        self.max_audio_clip_s = processor.feature_extractor.chunk_length
+        self.model_sr = processor.feature_extractor.sampling_rate
+        self.hop_length = processor.feature_extractor.hop_length
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
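The constructor change above derives the clip-length limit and audio parameters from the model's own processor instead of the hard-coded constants removed earlier. As a point of reference (an assumption about typical Whisper checkpoints, not something this diff asserts), the relevant feature-extractor values can be inspected directly with Hugging Face transformers:

from transformers import WhisperProcessor

proc = WhisperProcessor.from_pretrained("openai/whisper-small")
fe = proc.feature_extractor
print(fe.chunk_length)    # max clip length in seconds (typically 30)
print(fe.sampling_rate)   # expected sampling rate in Hz (typically 16000)
print(fe.hop_length)      # STFT hop length in samples (typically 160)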
@@ -172,7 +177,7 @@ class OpenAIServingTranscription(OpenAIServing):
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> PromptType:
+    ) -> tuple[PromptType, float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -198,9 +203,11 @@ class OpenAIServingTranscription(OpenAIServing):
 
         with io.BytesIO(audio_data) as bytes_:
             y, sr = librosa.load(bytes_)
-        if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
+
+        duration = librosa.get_duration(y=y, sr=sr)
+        if duration > self.max_audio_clip_s:
             raise ValueError(
-                f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
+                f"Maximum clip duration ({self.max_audio_clip_s}s) "
                 "exceeded.")
 
         prompt = {
@@ -213,13 +220,13 @@ class OpenAIServingTranscription(OpenAIServing):
             "decoder_prompt":
             f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
         }
-        return cast(PromptType, prompt)
+        return cast(PromptType, prompt), duration
 
     # TODO (varun) : Make verbose response work !
     async def create_transcription(
         self, audio_data: bytes, request: TranscriptionRequest,
         raw_request: Request
-    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
+    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
                ErrorResponse]:
         """Transcription API similar to OpenAI's API.
 
@@ -240,8 +247,7 @@ class OpenAIServingTranscription(OpenAIServing):
             return self.create_error_response(
                 "Currently only support response_format `text` or `json`")
 
-        # TODO cmpl->transcription?
-        request_id = f"cmpl-{self._base_request_id(raw_request)}"
+        request_id = f"trsc-{self._base_request_id(raw_request)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
@@ -261,7 +267,7 @@ class OpenAIServingTranscription(OpenAIServing):
                 "Currently do not support PromptAdapter for Transcription."
             )
 
-        prompt = await self._preprocess_transcription(
+        prompt, duration_s = await self._preprocess_transcription(
            request=request,
            audio_data=audio_data,
        )
@@ -293,7 +299,12 @@ class OpenAIServingTranscription(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        # TODO(rob): figure out a way to pipe streaming in.
+        if request.stream:
+            return self.transcription_stream_generator(request,
+                                                       result_generator,
+                                                       request_id,
+                                                       request_metadata,
+                                                       duration_s)
         # Non-streaming response.
         try:
             assert result_generator is not None
@@ -305,3 +316,106 @@ class OpenAIServingTranscription(OpenAIServing):
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
+
+    async def transcription_stream_generator(
+            self, request: TranscriptionRequest,
+            result_generator: AsyncGenerator[RequestOutput, None],
+            request_id: str, request_metadata: RequestResponseMetadata,
+            audio_duration_s: float) -> AsyncGenerator[str, None]:
+        created_time = int(time.time())
+        model_name = request.model
+        chunk_object_type: Final = "transcription.chunk"
+
+        completion_tokens = 0
+        num_prompt_tokens = 0
+
+        include_usage = request.stream_include_usage \
+            if request.stream_include_usage else False
+        include_continuous_usage = request.stream_continuous_usage_stats\
+            if include_usage and request.stream_continuous_usage_stats\
+            else False
+
+        try:
+            async for res in result_generator:
+                # On first result.
+                if res.prompt_token_ids is not None:
+                    # Do not account the 4-tokens `<|startoftranscript|>..`
+                    # Could be negative when language token is not specified.
+                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
+                    # NOTE(NickLucche) user can't pass encoder prompts directly
+                    # at least not to Whisper. One indicator of the encoder
+                    # amount of processing is the log-mel spectogram length.
+                    num_prompt_tokens += ceil(audio_duration_s *
+                                              self.model_sr / self.hop_length)
+
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+
+                # Just one output (n=1) supported.
+                assert len(res.outputs) == 1
+                output = res.outputs[0]
+
+                delta_message = DeltaMessage(content=output.text)
+                completion_tokens += len(output.token_ids)
+
+                if output.finish_reason is None:
+                    # Still generating, send delta update.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message)
+                else:
+                    # Model is finished generating.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message,
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+
+                chunk = TranscriptionStreamResponse(id=request_id,
+                                                    object=chunk_object_type,
+                                                    created=created_time,
+                                                    choices=[choice_data],
+                                                    model=model_name)
+
+                # handle usage stats if requested & if continuous
+                if include_continuous_usage:
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=num_prompt_tokens + completion_tokens,
+                    )
+
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+
+            # Once the final token is handled, if stream_options.include_usage
+            # is sent, send the usage.
+            if include_usage:
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+
+                final_usage_chunk = TranscriptionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
+            # report to FastAPI middleware aggregate usage across all choices
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=num_prompt_tokens + completion_tokens)
+
+        except Exception as e:
+            # TODO: Use a vllm-specific Validation Error
+            logger.exception("Error in chat completion stream generator.")
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
+        # Send the final done message after all response.n are finished
+        yield "data: [DONE]\n\n"
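A worked example of the prompt-token estimate in `transcription_stream_generator` above: the audio contributes one "prompt token" per log-mel frame. The sampling rate and hop length are the typical Whisper values noted earlier (assumed, not taken from this diff).

from math import ceil

# Assumed typical Whisper values: 16 kHz sampling rate, hop length of 160 samples.
model_sr, hop_length = 16000, 160
audio_duration_s = 5.0

mel_frames = ceil(audio_duration_s * model_sr / hop_length)
print(mel_frames)  # 500 frames counted towards prompt_tokens in UsageInfo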