diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index ac98efb7b88a6..672663dc50b1e 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -456,6 +456,7 @@ For `verbose_json` response format:
   ]
 }
 ```
+Currently, the `verbose_json` response format does not support `avg_logprob`, `compression_ratio`, or `no_speech_prob`.
 
 #### Extra Parameters
 
diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py
index 82c50e58a0168..47cd7b1f12d00 100644
--- a/tests/entrypoints/openai/test_transcription_validation_whisper.py
+++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py
@@ -235,3 +235,16 @@ async def test_audio_prompt(mary_had_lamb, whisper_client):
     )
     out_prompt = json.loads(transcription_wprompt)["text"]
     assert prefix in out_prompt
+
+
+@pytest.mark.asyncio
+async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
+    transcription = await whisper_client.audio.transcriptions.create(
+        model=MODEL_NAME,
+        file=mary_had_lamb,
+        language="en",
+        response_format="verbose_json",
+        temperature=0.0,
+    )
+    assert transcription.segments is not None
+    assert len(transcription.segments) > 0
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 6a648822d9b2b..92161f67f1cf0 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -68,9 +68,9 @@ from vllm.entrypoints.openai.protocol import (
     TokenizeRequest,
     TokenizeResponse,
     TranscriptionRequest,
-    TranscriptionResponse,
+    TranscriptionResponseVariant,
     TranslationRequest,
-    TranslationResponse,
+    TranslationResponseVariant,
 )
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -809,7 +809,7 @@ async def create_transcriptions(
             content=generator.model_dump(), status_code=generator.error.code
         )
 
-    elif isinstance(generator, TranscriptionResponse):
+    elif isinstance(generator, TranscriptionResponseVariant):
         return JSONResponse(content=generator.model_dump())
 
     return StreamingResponse(content=generator, media_type="text/event-stream")
@@ -848,7 +848,7 @@ async def create_translations(
             content=generator.model_dump(), status_code=generator.error.code
         )
 
-    elif isinstance(generator, TranslationResponse):
+    elif isinstance(generator, TranslationResponseVariant):
         return JSONResponse(content=generator.model_dump())
 
     return StreamingResponse(content=generator, media_type="text/event-stream")
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fb73416f45b24..0f4b2b4d7aad0 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -2126,13 +2126,13 @@ class TranscriptionSegment(OpenAIBaseModel):
     id: int
     """Unique identifier of the segment."""
 
-    avg_logprob: float
+    avg_logprob: float | None = None
     """Average logprob of the segment.
 
     If the value is lower than -1, consider the logprobs failed.
     """
 
-    compression_ratio: float
+    compression_ratio: float | None = None
     """Compression ratio of the segment.
 
     If the value is greater than 2.4, consider the compression failed.
@@ -2141,7 +2141,7 @@ class TranscriptionSegment(OpenAIBaseModel):
     end: float
     """End time of the segment in seconds."""
 
-    no_speech_prob: float
+    no_speech_prob: float | None = None
     """Probability of no speech in the segment.
 
     If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
@@ -2181,6 +2181,11 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
     """Extracted words and their corresponding timestamps."""
 
 
+TranscriptionResponseVariant: TypeAlias = (
+    TranscriptionResponse | TranscriptionResponseVerbose
+)
+
+
 class TranslationResponseStreamChoice(OpenAIBaseModel):
     delta: DeltaMessage
     finish_reason: str | None = None
@@ -2325,13 +2330,13 @@ class TranslationSegment(OpenAIBaseModel):
     id: int
     """Unique identifier of the segment."""
 
-    avg_logprob: float
+    avg_logprob: float | None = None
     """Average logprob of the segment.
 
     If the value is lower than -1, consider the logprobs failed.
     """
 
-    compression_ratio: float
+    compression_ratio: float | None = None
     """Compression ratio of the segment.
 
     If the value is greater than 2.4, consider the compression failed.
@@ -2340,7 +2345,7 @@ class TranslationSegment(OpenAIBaseModel):
     end: float
     """End time of the segment in seconds."""
 
-    no_speech_prob: float
+    no_speech_prob: float | None = None
    """Probability of no speech in the segment.
 
     If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
@@ -2380,6 +2385,9 @@ class TranslationResponseVerbose(OpenAIBaseModel):
     """Extracted words and their corresponding timestamps."""
 
 
+TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
+
+
 ####### Tokens IN <> Tokens OUT #######
 class GenerateRequest(BaseModel):
     request_id: str = Field(
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 33da7034afabc..189b532810b43 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -12,10 +12,12 @@ from vllm.entrypoints.openai.protocol import (
     TranscriptionRequest,
     TranscriptionResponse,
     TranscriptionResponseStreamChoice,
+    TranscriptionResponseVerbose,
     TranscriptionStreamResponse,
     TranslationRequest,
     TranslationResponse,
     TranslationResponseStreamChoice,
+    TranslationResponseVerbose,
     TranslationStreamResponse,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -51,7 +53,12 @@ class OpenAIServingTranscription(OpenAISpeechToText):
 
     async def create_transcription(
         self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
-    ) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse:
+    ) -> (
+        TranscriptionResponse
+        | TranscriptionResponseVerbose
+        | AsyncGenerator[str, None]
+        | ErrorResponse
+    ):
         """Transcription API similar to OpenAI's API.
 
         See https://platform.openai.com/docs/api-reference/audio/createTranscription
@@ -61,7 +68,11 @@ class OpenAIServingTranscription(OpenAISpeechToText):
             audio_data=audio_data,
             request=request,
             raw_request=raw_request,
-            response_class=TranscriptionResponse,
+            response_class=(
+                TranscriptionResponseVerbose
+                if request.response_format == "verbose_json"
+                else TranscriptionResponse
+            ),
             stream_generator_method=self.transcription_stream_generator,
         )
 
@@ -112,7 +123,12 @@ class OpenAIServingTranslation(OpenAISpeechToText):
 
     async def create_translation(
         self, audio_data: bytes, request: TranslationRequest, raw_request: Request
-    ) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse:
+    ) -> (
+        TranslationResponse
+        | TranslationResponseVerbose
+        | AsyncGenerator[str, None]
+        | ErrorResponse
+    ):
         """Translation API similar to OpenAI's API.
 
         See https://platform.openai.com/docs/api-reference/audio/createTranslation
@@ -122,7 +138,11 @@ class OpenAIServingTranslation(OpenAISpeechToText):
             audio_data=audio_data,
             request=request,
             raw_request=raw_request,
-            response_class=TranslationResponse,
+            response_class=(
+                TranslationResponseVerbose
+                if request.response_format == "verbose_json"
+                else TranslationResponse
+            ),
             stream_generator_method=self.translation_stream_generator,
         )
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index 3dece07748cc4..b34446d3230b1 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -10,6 +10,7 @@ from typing import Literal, TypeAlias, TypeVar, cast
 
 import numpy as np
 from fastapi import Request
+from transformers import PreTrainedTokenizerBase
 
 import vllm.envs as envs
 from vllm.engine.protocol import EngineClient
@@ -20,9 +21,13 @@ from vllm.entrypoints.openai.protocol import (
     RequestResponseMetadata,
     TranscriptionResponse,
     TranscriptionResponseStreamChoice,
+    TranscriptionResponseVerbose,
+    TranscriptionSegment,
     TranscriptionStreamResponse,
     TranslationResponse,
     TranslationResponseStreamChoice,
+    TranslationResponseVerbose,
+    TranslationSegment,
     TranslationStreamResponse,
     UsageInfo,
 )
@@ -32,6 +37,7 @@ from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.import_utils import PlaceholderModule
 
 try:
@@ -40,7 +46,20 @@ except ImportError:
     librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
 
 SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
+SpeechToTextResponseVerbose: TypeAlias = (
+    TranscriptionResponseVerbose | TranslationResponseVerbose
+)
+SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
 T = TypeVar("T", bound=SpeechToTextResponse)
+V = TypeVar("V", bound=SpeechToTextResponseVerbose)
+S = TypeVar("S", bound=SpeechToTextSegment)
+
+ResponseType: TypeAlias = (
+    TranscriptionResponse
+    | TranslationResponse
+    | TranscriptionResponseVerbose
+    | TranslationResponseVerbose
+)
 
 logger = init_logger(__name__)
 
@@ -78,6 +97,14 @@ class OpenAISpeechToText(OpenAIServing):
         self.enable_force_include_usage = enable_force_include_usage
 
         self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
+        if self.model_cls.supports_segment_timestamp:
+            self.tokenizer = cast(
+                PreTrainedTokenizerBase,
+                get_tokenizer(
+                    tokenizer_name=self.model_config.tokenizer,
+                    tokenizer_mode=self.model_config.tokenizer_mode,
+                ),
+            )
 
         if self.default_sampling_params:
             logger.info(
@@ -133,17 +160,87 @@ class OpenAISpeechToText(OpenAIServing):
                 request_prompt=request.prompt,
                 to_language=to_language,
             )
+            if request.response_format == "verbose_json":
+                if not isinstance(prompt, dict):
+                    raise ValueError(f"Expected prompt to be a dict, got {type(prompt)}")
+                prompt_dict = cast(dict, prompt)
+                decoder_prompt = prompt.get("decoder_prompt")
+                if not isinstance(decoder_prompt, str):
+                    raise ValueError(
+                        f"Expected decoder_prompt to be a str, got {type(decoder_prompt)}"
+                    )
+                prompt_dict["decoder_prompt"] = decoder_prompt.replace(
+                    "<|notimestamps|>", "<|0.00|>"
+                )
             prompts.append(prompt)
         return prompts, duration
 
+    def _get_verbose_segments(
+        self,
+        tokens: tuple,
+        request: SpeechToTextRequest,
+        segment_class: type[SpeechToTextSegment],
+        start_time: float = 0,
+    ) -> list[SpeechToTextSegment]:
+        """
+        Convert tokens to verbose segments.
+
+        This method expects the model to produce
+        timestamps as tokens (similar to Whisper).
+        If the tokens do not include timestamp information,
+        the segments may not be generated correctly.
+
+        Note: Fields like avg_logprob, compression_ratio,
+        and no_speech_prob are not supported
+        in this implementation and will be None. See docs for details.
+        """
+        BASE_OFFSET = 0.02
+        init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
+        if tokens[-1] == self.tokenizer.eos_token_id:
+            tokens = tokens[:-1]
+
+        tokens_with_start = (init_token,) + tokens
+        segments: list[SpeechToTextSegment] = []
+        last_timestamp_start = 0
+
+        if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
+            tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
+        for idx, token in enumerate(tokens_with_start):
+            # Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
+            # If the ordering is violated, this slicing may produce incorrect results.
+            if (
+                token >= init_token
+                and idx != 0
+                and tokens_with_start[idx - 1] >= init_token
+            ):
+                sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
+                start_timestamp = sliced_timestamp_tokens[0] - init_token
+                end_timestamp = sliced_timestamp_tokens[-1] - init_token
+
+                casting_segment = cast(
+                    SpeechToTextSegment,
+                    segment_class(
+                        id=len(segments),
+                        seek=start_time,
+                        start=start_time + BASE_OFFSET * start_timestamp,
+                        end=start_time + BASE_OFFSET * end_timestamp,
+                        temperature=request.temperature,
+                        text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]),
+                        tokens=sliced_timestamp_tokens[1:-1],
+                    ),
+                )
+                segments.append(casting_segment)
+                last_timestamp_start = idx
+        return segments
+
     async def _create_speech_to_text(
         self,
         audio_data: bytes,
         request: SpeechToTextRequest,
         raw_request: Request,
-        response_class: type[T],
+        response_class: type[T | V],
         stream_generator_method: Callable[..., AsyncGenerator[str, None]],
-    ) -> T | AsyncGenerator[str, None] | ErrorResponse:
+    ) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
         """Base method for speech-to-text operations like transcription and
         translation."""
         error_check_ret = await self._check_model(request)
@@ -156,11 +253,24 @@ class OpenAISpeechToText(OpenAIServing):
         if self.engine_client.errored:
             raise self.engine_client.dead_error
 
-        if request.response_format not in ["text", "json"]:
+        if request.response_format not in ["text", "json", "verbose_json"]:
             return self.create_error_response(
-                "Currently only support response_format `text` or `json`"
+                "Currently only support response_format "
+                "`text`, `json` or `verbose_json`"
             )
 
+        if (
+            request.response_format == "verbose_json"
+            and not self.model_cls.supports_segment_timestamp
+        ):
+            return self.create_error_response(
+                f"Currently do not support verbose_json for {request.model}"
+            )
+
+        if request.response_format == "verbose_json" and request.stream:
+            return self.create_error_response(
+                "Streaming is not supported with the verbose_json response_format"
+            )
         request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
@@ -215,25 +325,69 @@ class OpenAISpeechToText(OpenAIServing):
             request, list_result_generator, request_id, request_metadata, duration_s
         )
         # Non-streaming response.
+        total_segments = []
+        text_parts = []
         try:
             assert list_result_generator is not None
+            segments_types: dict[str, type[SpeechToTextSegment]] = {
+                "transcribe": TranscriptionSegment,
+                "translate": TranslationSegment,
+            }
+            segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
             text = ""
-            for result_generator in list_result_generator:
+            for idx, result_generator in enumerate(list_result_generator):
                 async for op in result_generator:
-                    text += op.outputs[0].text
+                    if request.response_format == "verbose_json":
+                        segments: list[SpeechToTextSegment] = (
+                            self._get_verbose_segments(
+                                tokens=tuple(op.outputs[0].token_ids),
+                                segment_class=segment_class,
+                                request=request,
+                                start_time=idx * self.asr_config.max_audio_clip_s,
+                            )
+                        )
+                        total_segments.extend(segments)
+                        text_parts.extend([seg.text for seg in segments])
+                    else:
+                        text_parts.append(op.outputs[0].text)
+            text = "".join(text_parts)
 
             if self.task_type == "transcribe":
+                final_response: ResponseType
                 # add usage in TranscriptionResponse.
                 usage = {
                     "type": "duration",
                     # rounded up as per openAI specs
                     "seconds": int(math.ceil(duration_s)),
                 }
-                final_response = cast(T, response_class(text=text, usage=usage))
+                if request.response_format != "verbose_json":
+                    final_response = cast(
+                        T, TranscriptionResponse(text=text, usage=usage)
+                    )
+                else:
+                    final_response = cast(
+                        V,
+                        TranscriptionResponseVerbose(
+                            text=text,
+                            language=request.language,
+                            duration=str(duration_s),
+                            segments=total_segments,
+                        ),
+                    )
             else:
                 # no usage in response for translation task
-                final_response = cast(T, response_class(text=text))  # type: ignore[call-arg]
-
+                if request.response_format != "verbose_json":
+                    final_response = cast(T, TranslationResponse(text=text))
+                else:
+                    final_response = cast(
+                        V,
+                        TranslationResponseVerbose(
+                            text=text,
+                            language=request.language,
+                            duration=str(duration_s),
+                            segments=total_segments,
+                        ),
+                    )
             return final_response
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index ccd5be42e65a9..0f65683cf7c57 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -837,6 +837,10 @@ class SupportsTranscription(Protocol):
     Transcription models can opt out of text generation by setting this
     to `True`.
     """
+    supports_segment_timestamp: ClassVar[bool] = False
+    """
+    Set this to `True` for models that emit segment timestamps (used by `verbose_json`).
+    """
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index c72b5e1c091f2..1ed6ae4366d0c 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -791,6 +791,7 @@ class WhisperForConditionalGeneration(
     # Whisper only supports audio-conditioned generation.
     supports_transcription_only = True
+    supports_segment_timestamp = True
 
     supported_languages = ISO639_1_SUPPORTED_LANGS
 
     @classmethod
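For reference, a minimal client-side sketch of how the new `verbose_json` support can be exercised, mirroring `test_audio_with_timestamp` above. It is not part of the patch; the server URL, API key, audio file name, and model name are placeholder assumptions.

```python
# Illustration only: querying a running vLLM OpenAI-compatible server for
# segment-level timestamps. base_url, api_key, file, and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("mary_had_lamb.ogg", "rb") as audio:
    result = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=audio,
        language="en",
        response_format="verbose_json",
        temperature=0.0,
    )

# Each segment carries start/end times derived from the model's timestamp tokens;
# avg_logprob, compression_ratio, and no_speech_prob are None (see docs note).
for seg in result.segments:
    print(f"[{seg.start:6.2f} -> {seg.end:6.2f}] {seg.text}")
```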