mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-25 22:27:10 +08:00
[Frontend] add 'verbose_json' and 'timestamp' feature on Whisper Transcription/Translation (#24209)
Signed-off-by: sangbumlikeagod <oironese@naver.com> Signed-off-by: sangbumlikeagod <98077576+sangbumlikeagod@users.noreply.github.com>
This commit is contained in:
parent
5d43f7372e
commit
092bb73b8a
@ -456,6 +456,7 @@ For `verbose_json` response format:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
Currently “verbose_json” response format doesn’t support avg_logprob, compression_ratio, no_speech_prob.
|
||||||
|
|
||||||
#### Extra Parameters
|
#### Extra Parameters
|
||||||
|
|
||||||
|
|||||||
@ -235,3 +235,16 @@ async def test_audio_prompt(mary_had_lamb, whisper_client):
|
|||||||
)
|
)
|
||||||
out_prompt = json.loads(transcription_wprompt)["text"]
|
out_prompt = json.loads(transcription_wprompt)["text"]
|
||||||
assert prefix in out_prompt
|
assert prefix in out_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
|
||||||
|
transcription = await whisper_client.audio.transcriptions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
file=mary_had_lamb,
|
||||||
|
language="en",
|
||||||
|
response_format="verbose_json",
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
assert transcription.segments is not None
|
||||||
|
assert len(transcription.segments) > 0
|
||||||
|
|||||||
@ -68,9 +68,9 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
TranscriptionRequest,
|
TranscriptionRequest,
|
||||||
TranscriptionResponse,
|
TranscriptionResponseVariant,
|
||||||
TranslationRequest,
|
TranslationRequest,
|
||||||
TranslationResponse,
|
TranslationResponseVariant,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
@ -809,7 +809,7 @@ async def create_transcriptions(
|
|||||||
content=generator.model_dump(), status_code=generator.error.code
|
content=generator.model_dump(), status_code=generator.error.code
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(generator, TranscriptionResponse):
|
elif isinstance(generator, TranscriptionResponseVariant):
|
||||||
return JSONResponse(content=generator.model_dump())
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
@ -848,7 +848,7 @@ async def create_translations(
|
|||||||
content=generator.model_dump(), status_code=generator.error.code
|
content=generator.model_dump(), status_code=generator.error.code
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(generator, TranslationResponse):
|
elif isinstance(generator, TranslationResponseVariant):
|
||||||
return JSONResponse(content=generator.model_dump())
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|||||||
@ -2126,13 +2126,13 @@ class TranscriptionSegment(OpenAIBaseModel):
|
|||||||
id: int
|
id: int
|
||||||
"""Unique identifier of the segment."""
|
"""Unique identifier of the segment."""
|
||||||
|
|
||||||
avg_logprob: float
|
avg_logprob: float | None = None
|
||||||
"""Average logprob of the segment.
|
"""Average logprob of the segment.
|
||||||
|
|
||||||
If the value is lower than -1, consider the logprobs failed.
|
If the value is lower than -1, consider the logprobs failed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compression_ratio: float
|
compression_ratio: float | None = None
|
||||||
"""Compression ratio of the segment.
|
"""Compression ratio of the segment.
|
||||||
|
|
||||||
If the value is greater than 2.4, consider the compression failed.
|
If the value is greater than 2.4, consider the compression failed.
|
||||||
@ -2141,7 +2141,7 @@ class TranscriptionSegment(OpenAIBaseModel):
|
|||||||
end: float
|
end: float
|
||||||
"""End time of the segment in seconds."""
|
"""End time of the segment in seconds."""
|
||||||
|
|
||||||
no_speech_prob: float
|
no_speech_prob: float | None = None
|
||||||
"""Probability of no speech in the segment.
|
"""Probability of no speech in the segment.
|
||||||
|
|
||||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||||
@ -2181,6 +2181,11 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
|||||||
"""Extracted words and their corresponding timestamps."""
|
"""Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
TranscriptionResponseVariant: TypeAlias = (
|
||||||
|
TranscriptionResponse | TranscriptionResponseVerbose
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||||
delta: DeltaMessage
|
delta: DeltaMessage
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
@ -2325,13 +2330,13 @@ class TranslationSegment(OpenAIBaseModel):
|
|||||||
id: int
|
id: int
|
||||||
"""Unique identifier of the segment."""
|
"""Unique identifier of the segment."""
|
||||||
|
|
||||||
avg_logprob: float
|
avg_logprob: float | None = None
|
||||||
"""Average logprob of the segment.
|
"""Average logprob of the segment.
|
||||||
|
|
||||||
If the value is lower than -1, consider the logprobs failed.
|
If the value is lower than -1, consider the logprobs failed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compression_ratio: float
|
compression_ratio: float | None = None
|
||||||
"""Compression ratio of the segment.
|
"""Compression ratio of the segment.
|
||||||
|
|
||||||
If the value is greater than 2.4, consider the compression failed.
|
If the value is greater than 2.4, consider the compression failed.
|
||||||
@ -2340,7 +2345,7 @@ class TranslationSegment(OpenAIBaseModel):
|
|||||||
end: float
|
end: float
|
||||||
"""End time of the segment in seconds."""
|
"""End time of the segment in seconds."""
|
||||||
|
|
||||||
no_speech_prob: float
|
no_speech_prob: float | None = None
|
||||||
"""Probability of no speech in the segment.
|
"""Probability of no speech in the segment.
|
||||||
|
|
||||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||||
@ -2380,6 +2385,9 @@ class TranslationResponseVerbose(OpenAIBaseModel):
|
|||||||
"""Extracted words and their corresponding timestamps."""
|
"""Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
||||||
|
|
||||||
|
|
||||||
####### Tokens IN <> Tokens OUT #######
|
####### Tokens IN <> Tokens OUT #######
|
||||||
class GenerateRequest(BaseModel):
|
class GenerateRequest(BaseModel):
|
||||||
request_id: str = Field(
|
request_id: str = Field(
|
||||||
|
|||||||
@ -12,10 +12,12 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
TranscriptionRequest,
|
TranscriptionRequest,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TranscriptionResponseStreamChoice,
|
TranscriptionResponseStreamChoice,
|
||||||
|
TranscriptionResponseVerbose,
|
||||||
TranscriptionStreamResponse,
|
TranscriptionStreamResponse,
|
||||||
TranslationRequest,
|
TranslationRequest,
|
||||||
TranslationResponse,
|
TranslationResponse,
|
||||||
TranslationResponseStreamChoice,
|
TranslationResponseStreamChoice,
|
||||||
|
TranslationResponseVerbose,
|
||||||
TranslationStreamResponse,
|
TranslationStreamResponse,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
@ -51,7 +53,12 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
|||||||
|
|
||||||
async def create_transcription(
|
async def create_transcription(
|
||||||
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
|
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
|
||||||
) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse:
|
) -> (
|
||||||
|
TranscriptionResponse
|
||||||
|
| TranscriptionResponseVerbose
|
||||||
|
| AsyncGenerator[str, None]
|
||||||
|
| ErrorResponse
|
||||||
|
):
|
||||||
"""Transcription API similar to OpenAI's API.
|
"""Transcription API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
See https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||||
@ -61,7 +68,11 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
|||||||
audio_data=audio_data,
|
audio_data=audio_data,
|
||||||
request=request,
|
request=request,
|
||||||
raw_request=raw_request,
|
raw_request=raw_request,
|
||||||
response_class=TranscriptionResponse,
|
response_class=(
|
||||||
|
TranscriptionResponseVerbose
|
||||||
|
if request.response_format == "verbose_json"
|
||||||
|
else TranscriptionResponse
|
||||||
|
),
|
||||||
stream_generator_method=self.transcription_stream_generator,
|
stream_generator_method=self.transcription_stream_generator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,7 +123,12 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
|||||||
|
|
||||||
async def create_translation(
|
async def create_translation(
|
||||||
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
|
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
|
||||||
) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse:
|
) -> (
|
||||||
|
TranslationResponse
|
||||||
|
| TranslationResponseVerbose
|
||||||
|
| AsyncGenerator[str, None]
|
||||||
|
| ErrorResponse
|
||||||
|
):
|
||||||
"""Translation API similar to OpenAI's API.
|
"""Translation API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||||
@ -122,7 +138,11 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
|||||||
audio_data=audio_data,
|
audio_data=audio_data,
|
||||||
request=request,
|
request=request,
|
||||||
raw_request=raw_request,
|
raw_request=raw_request,
|
||||||
response_class=TranslationResponse,
|
response_class=(
|
||||||
|
TranslationResponseVerbose
|
||||||
|
if request.response_format == "verbose_json"
|
||||||
|
else TranslationResponse
|
||||||
|
),
|
||||||
stream_generator_method=self.translation_stream_generator,
|
stream_generator_method=self.translation_stream_generator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from typing import Literal, TypeAlias, TypeVar, cast
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
@ -20,9 +21,13 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TranscriptionResponseStreamChoice,
|
TranscriptionResponseStreamChoice,
|
||||||
|
TranscriptionResponseVerbose,
|
||||||
|
TranscriptionSegment,
|
||||||
TranscriptionStreamResponse,
|
TranscriptionStreamResponse,
|
||||||
TranslationResponse,
|
TranslationResponse,
|
||||||
TranslationResponseStreamChoice,
|
TranslationResponseStreamChoice,
|
||||||
|
TranslationResponseVerbose,
|
||||||
|
TranslationSegment,
|
||||||
TranslationStreamResponse,
|
TranslationStreamResponse,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
@ -32,6 +37,7 @@ from vllm.inputs.data import PromptType
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models import SupportsTranscription
|
from vllm.model_executor.models import SupportsTranscription
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
from vllm.utils.import_utils import PlaceholderModule
|
from vllm.utils.import_utils import PlaceholderModule
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -40,7 +46,20 @@ except ImportError:
|
|||||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||||
|
SpeechToTextResponseVerbose: TypeAlias = (
|
||||||
|
TranscriptionResponseVerbose | TranslationResponseVerbose
|
||||||
|
)
|
||||||
|
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
|
||||||
T = TypeVar("T", bound=SpeechToTextResponse)
|
T = TypeVar("T", bound=SpeechToTextResponse)
|
||||||
|
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
|
||||||
|
S = TypeVar("S", bound=SpeechToTextSegment)
|
||||||
|
|
||||||
|
ResponseType: TypeAlias = (
|
||||||
|
TranscriptionResponse
|
||||||
|
| TranslationResponse
|
||||||
|
| TranscriptionResponseVerbose
|
||||||
|
| TranslationResponseVerbose
|
||||||
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -78,6 +97,14 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
self.enable_force_include_usage = enable_force_include_usage
|
self.enable_force_include_usage = enable_force_include_usage
|
||||||
|
|
||||||
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
|
||||||
|
if self.model_cls.supports_segment_timestamp:
|
||||||
|
self.tokenizer = cast(
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
get_tokenizer(
|
||||||
|
tokenizer_name=self.model_config.tokenizer,
|
||||||
|
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if self.default_sampling_params:
|
if self.default_sampling_params:
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -133,17 +160,87 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
request_prompt=request.prompt,
|
request_prompt=request.prompt,
|
||||||
to_language=to_language,
|
to_language=to_language,
|
||||||
)
|
)
|
||||||
|
if request.response_format == "verbose_json":
|
||||||
|
if not isinstance(prompt, dict):
|
||||||
|
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
|
||||||
|
prompt_dict = cast(dict, prompt)
|
||||||
|
decoder_prompt = prompt.get("decoder_prompt")
|
||||||
|
if not isinstance(decoder_prompt, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
|
||||||
|
)
|
||||||
|
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
|
||||||
|
"<|notimestamps|>", "<|0.00|>"
|
||||||
|
)
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
return prompts, duration
|
return prompts, duration
|
||||||
|
|
||||||
|
def _get_verbose_segments(
|
||||||
|
self,
|
||||||
|
tokens: tuple,
|
||||||
|
request: SpeechToTextRequest,
|
||||||
|
segment_class: type[SpeechToTextSegment],
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[SpeechToTextSegment]:
|
||||||
|
"""
|
||||||
|
Convert tokens to verbose segments.
|
||||||
|
|
||||||
|
This method expects the model to produce
|
||||||
|
timestamps as tokens (similar to Whisper).
|
||||||
|
If the tokens do not include timestamp information,
|
||||||
|
the segments may not be generated correctly.
|
||||||
|
|
||||||
|
Note: Fields like avg_logprob, compression_ratio,
|
||||||
|
and no_speech_prob are not supported
|
||||||
|
in this implementation and will be None. See docs for details.
|
||||||
|
"""
|
||||||
|
BASE_OFFSET = 0.02
|
||||||
|
init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
|
||||||
|
if tokens[-1] == self.tokenizer.eos_token_id:
|
||||||
|
tokens = tokens[:-1]
|
||||||
|
|
||||||
|
tokens_with_start = (init_token,) + tokens
|
||||||
|
segments: list[SpeechToTextSegment] = []
|
||||||
|
last_timestamp_start = 0
|
||||||
|
|
||||||
|
if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
|
||||||
|
tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
|
||||||
|
for idx, token in enumerate(tokens_with_start):
|
||||||
|
# Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
|
||||||
|
# If the ordering is violated, this slicing may produce incorrect results.
|
||||||
|
if (
|
||||||
|
token >= init_token
|
||||||
|
and idx != 0
|
||||||
|
and tokens_with_start[idx - 1] >= init_token
|
||||||
|
):
|
||||||
|
sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
|
||||||
|
start_timestamp = sliced_timestamp_tokens[0] - init_token
|
||||||
|
end_timestamp = sliced_timestamp_tokens[-1] - init_token
|
||||||
|
|
||||||
|
casting_segment = cast(
|
||||||
|
SpeechToTextSegment,
|
||||||
|
segment_class(
|
||||||
|
id=len(segments),
|
||||||
|
seek=start_time,
|
||||||
|
start=start_time + BASE_OFFSET * start_timestamp,
|
||||||
|
end=start_time + BASE_OFFSET * end_timestamp,
|
||||||
|
temperature=request.temperature,
|
||||||
|
text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
|
||||||
|
tokens=sliced_timestamp_tokens[1:-1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
segments.append(casting_segment)
|
||||||
|
last_timestamp_start = idx
|
||||||
|
return segments
|
||||||
|
|
||||||
async def _create_speech_to_text(
|
async def _create_speech_to_text(
|
||||||
self,
|
self,
|
||||||
audio_data: bytes,
|
audio_data: bytes,
|
||||||
request: SpeechToTextRequest,
|
request: SpeechToTextRequest,
|
||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
response_class: type[T],
|
response_class: type[T | V],
|
||||||
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
stream_generator_method: Callable[..., AsyncGenerator[str, None]],
|
||||||
) -> T | AsyncGenerator[str, None] | ErrorResponse:
|
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
|
||||||
"""Base method for speech-to-text operations like transcription and
|
"""Base method for speech-to-text operations like transcription and
|
||||||
translation."""
|
translation."""
|
||||||
error_check_ret = await self._check_model(request)
|
error_check_ret = await self._check_model(request)
|
||||||
@ -156,11 +253,24 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
if self.engine_client.errored:
|
if self.engine_client.errored:
|
||||||
raise self.engine_client.dead_error
|
raise self.engine_client.dead_error
|
||||||
|
|
||||||
if request.response_format not in ["text", "json"]:
|
if request.response_format not in ["text", "json", "verbose_json"]:
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"Currently only support response_format `text` or `json`"
|
("Currently only support response_format")
|
||||||
|
+ ("`text`, `json` or `verbose_json`")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
request.response_format == "verbose_json"
|
||||||
|
and not self.model_cls.supports_segment_timestamp
|
||||||
|
):
|
||||||
|
return self.create_error_response(
|
||||||
|
f"Currently do not support verbose_json for {request.model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.response_format == "verbose_json" and request.stream:
|
||||||
|
return self.create_error_response(
|
||||||
|
"verbose_json format doesn't support streaming case"
|
||||||
|
)
|
||||||
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
|
||||||
|
|
||||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||||
@ -215,25 +325,69 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
request, list_result_generator, request_id, request_metadata, duration_s
|
request, list_result_generator, request_id, request_metadata, duration_s
|
||||||
)
|
)
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
|
total_segments = []
|
||||||
|
text_parts = []
|
||||||
try:
|
try:
|
||||||
assert list_result_generator is not None
|
assert list_result_generator is not None
|
||||||
|
segments_types: dict[str, type[SpeechToTextSegment]] = {
|
||||||
|
"transcribe": TranscriptionSegment,
|
||||||
|
"translate": TranslationSegment,
|
||||||
|
}
|
||||||
|
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
||||||
text = ""
|
text = ""
|
||||||
for result_generator in list_result_generator:
|
for idx, result_generator in enumerate(list_result_generator):
|
||||||
async for op in result_generator:
|
async for op in result_generator:
|
||||||
text += op.outputs[0].text
|
if request.response_format == "verbose_json":
|
||||||
|
segments: list[SpeechToTextSegment] = (
|
||||||
|
self._get_verbose_segments(
|
||||||
|
tokens=tuple(op.outputs[0].token_ids),
|
||||||
|
segment_class=segment_class,
|
||||||
|
request=request,
|
||||||
|
start_time=idx * self.asr_config.max_audio_clip_s,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_segments.extend(segments)
|
||||||
|
text_parts.extend([seg.text for seg in segments])
|
||||||
|
else:
|
||||||
|
text_parts.append(op.outputs[0].text)
|
||||||
|
text = "".join(text_parts)
|
||||||
if self.task_type == "transcribe":
|
if self.task_type == "transcribe":
|
||||||
|
final_response: ResponseType
|
||||||
# add usage in TranscriptionResponse.
|
# add usage in TranscriptionResponse.
|
||||||
usage = {
|
usage = {
|
||||||
"type": "duration",
|
"type": "duration",
|
||||||
# rounded up as per openAI specs
|
# rounded up as per openAI specs
|
||||||
"seconds": int(math.ceil(duration_s)),
|
"seconds": int(math.ceil(duration_s)),
|
||||||
}
|
}
|
||||||
final_response = cast(T, response_class(text=text, usage=usage))
|
if request.response_format != "verbose_json":
|
||||||
|
final_response = cast(
|
||||||
|
T, TranscriptionResponse(text=text, usage=usage)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_response = cast(
|
||||||
|
V,
|
||||||
|
TranscriptionResponseVerbose(
|
||||||
|
text=text,
|
||||||
|
language=request.language,
|
||||||
|
duration=str(duration_s),
|
||||||
|
segments=total_segments,
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# no usage in response for translation task
|
# no usage in response for translation task
|
||||||
final_response = cast(T, response_class(text=text)) # type: ignore[call-arg]
|
if request.response_format != "verbose_json":
|
||||||
|
final_response = cast(T, TranslationResponse(text=text))
|
||||||
|
else:
|
||||||
|
final_response = cast(
|
||||||
|
V,
|
||||||
|
TranslationResponseVerbose(
|
||||||
|
text=text,
|
||||||
|
language=request.language,
|
||||||
|
duration=str(duration_s),
|
||||||
|
segments=total_segments,
|
||||||
|
),
|
||||||
|
)
|
||||||
return final_response
|
return final_response
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return self.create_error_response("Client disconnected")
|
return self.create_error_response("Client disconnected")
|
||||||
|
|||||||
@ -837,6 +837,10 @@ class SupportsTranscription(Protocol):
|
|||||||
Transcription models can opt out of text generation by setting this to
|
Transcription models can opt out of text generation by setting this to
|
||||||
`True`.
|
`True`.
|
||||||
"""
|
"""
|
||||||
|
supports_segment_timestamp: ClassVar[bool] = False
|
||||||
|
"""
|
||||||
|
Enables the segment timestamp option for supported models by setting this to `True`.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|||||||
@ -791,6 +791,7 @@ class WhisperForConditionalGeneration(
|
|||||||
|
|
||||||
# Whisper only supports audio-conditioned generation.
|
# Whisper only supports audio-conditioned generation.
|
||||||
supports_transcription_only = True
|
supports_transcription_only = True
|
||||||
|
supports_segment_timestamp = True
|
||||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user