From ca702a14dc2d4c5c077dbb8098e66ca244cea185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 16 Dec 2025 20:36:49 +0100 Subject: [PATCH] [Frontend] Add `max-completion-token` option to transcription/translation endpoints (#30769) Signed-off-by: NickLucche --- .../test_transcription_validation_whisper.py | 32 ++++++++++++++++++ .../openai/test_translation_validation.py | 33 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 6 ++++ vllm/entrypoints/openai/speech_to_text.py | 10 ++++-- 4 files changed, 79 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 3c507ee0a3fa7..8bf729c517f7a 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client): ) assert transcription.segments is not None assert len(transcription.segments) > 0 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(whisper_client, mary_had_lamb): + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(MODEL_NAME) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index d7d407484f16d..2c577237691ab 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model): ) out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(mary_had_lamb, client_and_model): + client, model_name = client_and_model + transcription = await client.audio.translations.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a7c4980cd3674..94dde4564ea0c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel): presence_penalty: float | None = 0.0 """The presence penalty to use for sampling.""" + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:transcription-sampling-params] # Default sampling parameters for transcription requests. @@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. stream_include_usage: bool | None = False stream_continuous_usage_stats: bool | None = False + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:translation-extra-params] # Default sampling parameters for translation requests. diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index cea9924ebbaca..df9c06adb105a 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing): try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. - default_max_tokens = self.model_config.max_model_len + # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be + # generated by respecting the extra completion tokens arg. + if request.max_completion_tokens is None: + default_max_tokens = self.model_config.max_model_len + else: + default_max_tokens = min( + self.model_config.max_model_len, request.max_completion_tokens + ) sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params )