[Frontend] Add max-completion-token option to transcription/translation endpoints (#30769)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-12-16 20:36:49 +01:00 committed by GitHub
parent 10ee1c64cf
commit ca702a14dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 2 deletions

View File

@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
)
assert transcription.segments is not None
assert len(transcription.segments) > 0
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len

View File

@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
)
out = json.loads(translation)["text"].strip().lower()
assert out.count("greek sea") == 2
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
client, model_name = client_and_model
transcription = await client.audio.translations.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(model_name)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len

View File

@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty: float | None = 0.0
"""The presence penalty to use for sampling."""
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests.
@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.

View File

@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing):
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)