mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-17 05:37:07 +08:00
[Frontend] add 'verbose_json' and 'timestamp' feature on Whisper Transcription/Translation (#24209)
Signed-off-by: sangbumlikeagod <oironese@naver.com> Signed-off-by: sangbumlikeagod <98077576+sangbumlikeagod@users.noreply.github.com>
This commit is contained in:
parent
5d43f7372e
commit
092bb73b8a
@ -456,6 +456,7 @@ For `verbose_json` response format:
|
||||
]
|
||||
}
|
||||
```
|
||||
Currently “verbose_json” response format doesn’t support avg_logprob, compression_ratio, no_speech_prob.
|
||||
|
||||
#### Extra Parameters
|
||||
|
||||
|
||||
@ -235,3 +235,16 @@ async def test_audio_prompt(mary_had_lamb, whisper_client):
|
||||
)
|
||||
out_prompt = json.loads(transcription_wprompt)["text"]
|
||||
assert prefix in out_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
|
||||
transcription = await whisper_client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=mary_had_lamb,
|
||||
language="en",
|
||||
response_format="verbose_json",
|
||||
temperature=0.0,
|
||||
)
|
||||
assert transcription.segments is not None
|
||||
assert len(transcription.segments) > 0
|
||||
|
||||
@ -68,9 +68,9 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseVariant,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
@ -809,7 +809,7 @@ async def create_transcriptions(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponse):
|
||||
elif isinstance(generator, TranscriptionResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
@ -848,7 +848,7 @@ async def create_translations(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranslationResponse):
|
||||
elif isinstance(generator, TranslationResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
@ -2126,13 +2126,13 @@ class TranscriptionSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
avg_logprob: float | None = None
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
compression_ratio: float | None = None
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
@ -2141,7 +2141,7 @@ class TranscriptionSegment(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
@ -2181,6 +2181,11 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranscriptionResponseVariant: TypeAlias = (
|
||||
TranscriptionResponse | TranscriptionResponseVerbose
|
||||
)
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
@ -2325,13 +2330,13 @@ class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
avg_logprob: float | None = None
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
compression_ratio: float | None = None
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
@ -2340,7 +2345,7 @@ class TranslationSegment(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
@ -2380,6 +2385,9 @@ class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
|
||||
@ -12,10 +12,12 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@ -51,7 +53,12 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
|
||||
) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse:
|
||||
) -> (
|
||||
TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Transcription API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
@ -61,7 +68,11 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranscriptionResponse,
|
||||
response_class=(
|
||||
TranscriptionResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranscriptionResponse
|
||||
),
|
||||
stream_generator_method=self.transcription_stream_generator,
|
||||
)
|
||||
|
||||
@ -112,7 +123,12 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
|
||||
) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse:
|
||||
) -> (
|
||||
TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
@ -122,7 +138,11 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=TranslationResponse,
|
||||
response_class=(
|
||||
TranslationResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranslationResponse
|
||||
),
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from typing import Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@ -20,9 +21,13 @@ from vllm.entrypoints.openai.protocol import (
|
||||
RequestResponseMetadata,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionSegment,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
@ -32,6 +37,7 @@ from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
@ -40,7 +46,20 @@ except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||
SpeechToTextResponseVerbose: TypeAlias = (
|
||||
TranscriptionResponseVerbose | TranslationResponseVerbose
|
||||
)
|
||||
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
|
||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
|
||||
S = TypeVar("S", bound=SpeechToTextSegment)
|
||||
|
||||
ResponseType: TypeAlias = (
|
||||
TranscriptionResponse
|
||||
| TranslationResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponseVerbose
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -78,6 +97,14 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||
if self.model_cls.supports_segment_timestamp:
|
||||
self.tokenizer = cast(
|
||||
PreTrainedTokenizerBase,
|
||||
get_tokenizer(
|
||||
tokenizer_name=self.model_config.tokenizer,
|
||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||
),
|
||||
)
|
||||
|
||||
if self.default_sampling_params:
|
||||
logger.info(
|
||||
@ -133,17 +160,87 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
if request.response_format == "verbose_json":
|
||||
if not isinstance(prompt, dict):
|
||||
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
|
||||
prompt_dict = cast(dict, prompt)
|
||||
decoder_prompt = prompt.get("decoder_prompt")
|
||||
if not isinstance(decoder_prompt, str):
|
||||
raise ValueError(
|
||||
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
|
||||
)
|
||||
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
|
||||
"<|notimestamps|>", "<|0.00|>"
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, duration
|
||||
|
||||
def _get_verbose_segments(
|
||||
self,
|
||||
tokens: tuple,
|
||||
request: SpeechToTextRequest,
|
||||
segment_class: type[SpeechToTextSegment],
|
||||
start_time: float = 0,
|
||||
) -> list[SpeechToTextSegment]:
|
||||
"""
|
||||
Convert tokens to verbose segments.
|
||||
|
||||
This method expects the model to produce
|
||||
timestamps as tokens (similar to Whisper).
|
||||
If the tokens do not include timestamp information,
|
||||
the segments may not be generated correctly.
|
||||
|
||||
Note: Fields like avg_logprob, compression_ratio,
|
||||
and no_speech_prob are not supported
|
||||
in this implementation and will be None. See docs for details.
|
||||
"""
|
||||
BASE_OFFSET = 0.02
|
||||
init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
|
||||
if tokens[-1] == self.tokenizer.eos_token_id:
|
||||
tokens = tokens[:-1]
|
||||
|
||||
tokens_with_start = (init_token,) + tokens
|
||||
segments: list[SpeechToTextSegment] = []
|
||||
last_timestamp_start = 0
|
||||
|
||||
if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
|
||||
tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
|
||||
for idx, token in enumerate(tokens_with_start):
|
||||
# Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
|
||||
# If the ordering is violated, this slicing may produce incorrect results.
|
||||
if (
|
||||
token >= init_token
|
||||
and idx != 0
|
||||
and tokens_with_start[idx - 1] >= init_token
|
||||
):
|
||||
sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
|
||||
start_timestamp = sliced_timestamp_tokens[0] - init_token
|
||||
end_timestamp = sliced_timestamp_tokens[-1] - init_token
|
||||
|
||||
casting_segment = cast(
|
||||
SpeechToTextSegment,
|
||||
segment_class(
|
||||
id=len(segments),
|
||||
seek=start_time,
|
||||
start=start_time + BASE_OFFSET * start_timestamp,
|
||||
end=start_time + BASE_OFFSET * end_timestamp,
|
||||
temperature=request.temperature,
|
||||
text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
|
||||
tokens=sliced_timestamp_tokens[1:-1],
|
||||
),
|
||||
)
|
||||
segments.append(casting_segment)
|
||||
last_timestamp_start = idx
|
||||
return segments
|
||||
|
||||
async def _create_speech_to_text(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: SpeechToTextRequest,
|
||||
raw_request: Request,
|
||||
response_class: type[T],
|
||||
response_class: type[T | V],
|
||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||
) -> T | AsyncGenerator[str, None] | ErrorResponse:
|
||||
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
error_check_ret = await self._check_model(request)
|
||||
@ -156,11 +253,24 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
if request.response_format not in ["text", "json"]:
|
||||
if request.response_format not in ["text", "json", "verbose_json"]:
|
||||
return self.create_error_response(
|
||||
"Currently only support response_format `text` or `json`"
|
||||
("Currently only support response_format")
|
||||
+ ("`text`, `json` or `verbose_json`")
|
||||
)
|
||||
|
||||
if (
|
||||
request.response_format == "verbose_json"
|
||||
and not self.model_cls.supports_segment_timestamp
|
||||
):
|
||||
return self.create_error_response(
|
||||
f"Currently do not support verbose_json for {request.model}"
|
||||
)
|
||||
|
||||
if request.response_format == "verbose_json" and request.stream:
|
||||
return self.create_error_response(
|
||||
"verbose_json format doesn't support streaming case"
|
||||
)
|
||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
@ -215,25 +325,69 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
request, list_result_generator, request_id, request_metadata, duration_s
|
||||
)
|
||||
# Non-streaming response.
|
||||
total_segments = []
|
||||
text_parts = []
|
||||
try:
|
||||
assert list_result_generator is not None
|
||||
segments_types: dict[str, type[SpeechToTextSegment]] = {
|
||||
"transcribe": TranscriptionSegment,
|
||||
"translate": TranslationSegment,
|
||||
}
|
||||
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
||||
text = ""
|
||||
for result_generator in list_result_generator:
|
||||
for idx, result_generator in enumerate(list_result_generator):
|
||||
async for op in result_generator:
|
||||
text += op.outputs[0].text
|
||||
if request.response_format == "verbose_json":
|
||||
segments: list[SpeechToTextSegment] = (
|
||||
self._get_verbose_segments(
|
||||
tokens=tuple(op.outputs[0].token_ids),
|
||||
segment_class=segment_class,
|
||||
request=request,
|
||||
start_time=idx * self.asr_config.max_audio_clip_s,
|
||||
)
|
||||
)
|
||||
|
||||
total_segments.extend(segments)
|
||||
text_parts.extend([seg.text for seg in segments])
|
||||
else:
|
||||
text_parts.append(op.outputs[0].text)
|
||||
text = "".join(text_parts)
|
||||
if self.task_type == "transcribe":
|
||||
final_response: ResponseType
|
||||
# add usage in TranscriptionResponse.
|
||||
usage = {
|
||||
"type": "duration",
|
||||
# rounded up as per openAI specs
|
||||
"seconds": int(math.ceil(duration_s)),
|
||||
}
|
||||
final_response = cast(T, response_class(text=text, usage=usage))
|
||||
if request.response_format != "verbose_json":
|
||||
final_response = cast(
|
||||
T, TranscriptionResponse(text=text, usage=usage)
|
||||
)
|
||||
else:
|
||||
final_response = cast(
|
||||
V,
|
||||
TranscriptionResponseVerbose(
|
||||
text=text,
|
||||
language=request.language,
|
||||
duration=str(duration_s),
|
||||
segments=total_segments,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# no usage in response for translation task
|
||||
final_response = cast(T, response_class(text=text)) # type: ignore[call-arg]
|
||||
|
||||
if request.response_format != "verbose_json":
|
||||
final_response = cast(T, TranslationResponse(text=text))
|
||||
else:
|
||||
final_response = cast(
|
||||
V,
|
||||
TranslationResponseVerbose(
|
||||
text=text,
|
||||
language=request.language,
|
||||
duration=str(duration_s),
|
||||
segments=total_segments,
|
||||
),
|
||||
)
|
||||
return final_response
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
||||
@ -837,6 +837,10 @@ class SupportsTranscription(Protocol):
|
||||
Transcription models can opt out of text generation by setting this to
|
||||
`True`.
|
||||
"""
|
||||
supports_segment_timestamp: ClassVar[bool] = False
|
||||
"""
|
||||
Enables the segment timestamp option for supported models by setting this to `True`.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@ -791,6 +791,7 @@ class WhisperForConditionalGeneration(
|
||||
|
||||
# Whisper only supports audio-conditioned generation.
|
||||
supports_transcription_only = True
|
||||
supports_segment_timestamp = True
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
@classmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user