[Frontend] Add sampling params to v1/audio/transcriptions endpoint (#16591)

Signed-off-by: Jannis Schönleber <joennlae@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Jannis Schönleber <joennlae@gmail.com>
Nicolò Lucchesi 2025-04-19 09:03:54 +02:00 committed by GitHub
parent 1d4680fad2
commit 2ef0dc53b8
4 changed files with 122 additions and 11 deletions


@@ -402,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
:::
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
<!-- TODO: api enforced limits + uploading audios -->
#### Extra Parameters
The following [sampling parameters](#sampling-params) are supported.
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-transcription-sampling-params
:end-before: end-transcription-sampling-params
:::
The following extra parameters are supported:
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-transcription-extra-params
:end-before: end-transcription-extra-params
:::
(tokenizer-api)=
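For readers not using the OpenAI SDK: the endpoint takes multipart form data, so the same sampling fields can be sent as plain form values. Below is a minimal sketch using `requests`, not part of this change; the server URL, audio file name, and chosen values are hypothetical, and it assumes the extra fields are accepted as flattened form fields (as the "simplify form data" comment in protocol.py suggests).

```python
# Hypothetical direct call to /v1/audio/transcriptions without the OpenAI SDK.
# Assumes a vLLM server on localhost:8000 and a local audio.wav file; the
# extra sampling fields are sent as flattened multipart form values.
import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("audio.wav", f, "audio/wav")},
        data={
            "model": "openai/whisper-large-v3",
            "language": "en",
            "response_format": "json",
            # Standard OpenAI field.
            "temperature": 0.2,
            # vLLM-specific sampling fields added by this change.
            "seed": 42,
            "top_p": 0.9,
            "top_k": 40,
        },
    )
resp.raise_for_status()
print(resp.json()["text"])
```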


@@ -26,7 +26,12 @@ def sync_openai():
            model="openai/whisper-large-v3",
            language="en",
            response_format="json",
            temperature=0.0,
            # Additional sampling params not provided by OpenAI API.
            extra_body=dict(
                seed=4419,
                repetition_penalty=1.3,
            ))
        print("transcription result:", transcription.text)


@@ -192,3 +192,36 @@ async def test_stream_options(winning_call):
        else:
            continuous = continuous and hasattr(chunk, 'usage')
    assert final and continuous


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb):
    """
    Compare sampling with params and greedy sampling to assert the results
    differ when extreme sampling parameter values are picked.
    """
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]

    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            temperature=0.8,
            extra_body=dict(seed=42,
                            repetition_penalty=1.9,
                            top_k=12,
                            top_p=0.4,
                            min_p=0.5,
                            frequency_penalty=1.8,
                            presence_penalty=2.0))

        greedy_transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            temperature=0.0,
            extra_body=dict(seed=42))

        assert greedy_transcription.text != transcription.text


@@ -1577,14 +1577,6 @@ class TranscriptionRequest(OpenAIBaseModel):
    """
    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.
    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[])
@@ -1596,6 +1588,7 @@ class TranscriptionRequest(OpenAIBaseModel):
    timestamps incurs additional latency.
    """

    # doc: begin-transcription-extra-params
    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
@@ -1604,10 +1597,51 @@ class TranscriptionRequest(OpenAIBaseModel):
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False
    # doc: end-transcription-extra-params

    # doc: begin-transcription-sampling-params
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.
    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: Optional[float] = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: Optional[int] = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: Optional[float] = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: Optional[float] = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: Optional[float] = None
    """The repetition penalty to use for sampling."""

    presence_penalty: Optional[float] = 0.0
    """The presence penalty to use for sampling."""
    # doc: end-transcription-sampling-params

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
    }

    def to_sampling_params(
@@ -1619,13 +1653,35 @@ class TranscriptionRequest(OpenAIBaseModel):
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])

        return SamplingParams.from_optional(temperature=temperature,
                                            max_tokens=max_tokens,
                                            seed=self.seed,
                                            top_p=top_p,
                                            top_k=top_k,
                                            min_p=min_p,
                                            frequency_penalty=self.frequency_penalty,
                                            repetition_penalty=repetition_penalty,
                                            presence_penalty=self.presence_penalty,
                                            output_kind=RequestOutputKind.DELTA
                                            if self.stream \
                                            else RequestOutputKind.FINAL_ONLY)
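To summarize the precedence that `to_sampling_params` implements: an explicit value on the request wins, then the server-supplied `default_sampling_params` (typically taken from the model's generation config), then the hard-coded `_DEFAULT_SAMPLING_PARAMS`. Below is a standalone sketch of that chain with hypothetical values; `resolve` is an illustrative helper, not part of the change.

```python
# Illustration of the fallback chain above:
# request value -> default_sampling_params -> _DEFAULT_SAMPLING_PARAMS.
_DEFAULT_SAMPLING_PARAMS = {
    "repetition_penalty": 1.0,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
}


def resolve(name, request_value, default_sampling_params):
    # Explicit request value takes priority; otherwise fall back to the
    # server-side defaults, then to the hard-coded constants.
    if request_value is not None:
        return request_value
    return default_sampling_params.get(name, _DEFAULT_SAMPLING_PARAMS[name])


# Client sent top_k=12 but no top_p; the model's generation config sets top_p.
print(resolve("top_k", 12, {"top_p": 0.9}))    # 12  (explicit request value)
print(resolve("top_p", None, {"top_p": 0.9}))  # 0.9 (server default)
print(resolve("min_p", None, {}))              # 0.0 (hard-coded fallback)
```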