mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 05:59:12 +08:00
[Frontend] Add max-completion-token option to transcription/translation endpoints (#30769)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
10ee1c64cf
commit
ca702a14dc
@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
|
|||||||
)
|
)
|
||||||
assert transcription.segments is not None
|
assert transcription.segments is not None
|
||||||
assert len(transcription.segments) > 0
|
assert len(transcription.segments) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
    """Check that ``max_completion_tokens`` bounds the transcription length.

    Two requests are sent to the transcriptions endpoint, each passing
    ``max_completion_tokens`` through ``extra_body``:

    1. A budget of 1, after which the returned text must re-tokenize to
       exactly one token.
    2. A budget of one million (intended to exceed the model's maximum
       length), after which the returned text must stay below 450 tokens.
    """

    async def request_text(token_budget):
        # Same request for both cases; only the token budget differs.
        raw_response = await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
            extra_body={"max_completion_tokens": token_budget},
        )
        return json.loads(raw_response)["text"]

    capped_text = await request_text(1)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def count_tokens(text):
        # Special tokens are excluded so only the generated text is counted.
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    assert count_tokens(capped_text) == 1

    # max_completion_tokens > max_model_len
    oversized_text = await request_text(int(1e6))
    assert count_tokens(oversized_text) < 450  # ~Whisper max output len
|
||||||
|
|||||||
@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
|
|||||||
)
|
)
|
||||||
out = json.loads(translation)["text"].strip().lower()
|
out = json.loads(translation)["text"].strip().lower()
|
||||||
assert out.count("greek sea") == 2
|
assert out.count("greek sea") == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
    """Check that ``max_completion_tokens`` bounds the translation length.

    Both requests go to the translations endpoint and pass
    ``max_completion_tokens`` through ``extra_body``:

    1. A budget of 1, after which the returned text must re-tokenize to
       exactly one token.
    2. A budget of one million (intended to exceed the model's maximum
       length), after which the returned text must stay below 450 tokens.
    """
    client, model_name = client_and_model
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": 1},
    )
    out_text = json.loads(translation)["text"]

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name)
    # Special tokens are excluded so only the generated text is counted.
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) == 1

    # max_completion_tokens > max_model_len
    # Stay on the translations endpoint here: this test covers translation,
    # so the over-limit case must exercise the same code path.
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": int(1e6)},
    )
    out_text = json.loads(translation)["text"]
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) < 450  # ~Whisper max output len
|
||||||
|
|||||||
@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
presence_penalty: float | None = 0.0
|
presence_penalty: float | None = 0.0
|
||||||
"""The presence penalty to use for sampling."""
|
"""The presence penalty to use for sampling."""
|
||||||
|
|
||||||
|
max_completion_tokens: int | None = None
|
||||||
|
"""The maximum number of tokens to generate."""
|
||||||
# --8<-- [end:transcription-sampling-params]
|
# --8<-- [end:transcription-sampling-params]
|
||||||
|
|
||||||
# Default sampling parameters for transcription requests.
|
# Default sampling parameters for transcription requests.
|
||||||
@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel):
|
|||||||
# Flattened stream option to simplify form data.
|
# Flattened stream option to simplify form data.
|
||||||
stream_include_usage: bool | None = False
|
stream_include_usage: bool | None = False
|
||||||
stream_continuous_usage_stats: bool | None = False
|
stream_continuous_usage_stats: bool | None = False
|
||||||
|
|
||||||
|
max_completion_tokens: int | None = None
|
||||||
|
"""The maximum number of tokens to generate."""
|
||||||
# --8<-- [end:translation-extra-params]
|
# --8<-- [end:translation-extra-params]
|
||||||
|
|
||||||
# Default sampling parameters for translation requests.
|
# Default sampling parameters for translation requests.
|
||||||
|
|||||||
@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
# Unlike most decoder-only models, whisper generation length is not
|
# Unlike most decoder-only models, whisper generation length is not
|
||||||
# constrained by the size of the input audio, which is mapped to a
|
# constrained by the size of the input audio, which is mapped to a
|
||||||
# fixed-size log-mel-spectrogram.
|
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
|
||||||
default_max_tokens = self.model_config.max_model_len
|
# generated by respecting the extra completion tokens arg.
|
||||||
|
if request.max_completion_tokens is None:
|
||||||
|
default_max_tokens = self.model_config.max_model_len
|
||||||
|
else:
|
||||||
|
default_max_tokens = min(
|
||||||
|
self.model_config.max_model_len, request.max_completion_tokens
|
||||||
|
)
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens, self.default_sampling_params
|
default_max_tokens, self.default_sampling_params
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user