From 2ef0dc53b88ded24930f10665a0575f39aef0cac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Sat, 19 Apr 2025 09:03:54 +0200
Subject: [PATCH] [Frontend] Add sampling params to `v1/audio/transcriptions`
 endpoint (#16591)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jannis Schönleber
Signed-off-by: NickLucche
Co-authored-by: Jannis Schönleber
---
 .../serving/openai_compatible_server.md            | 19 ++++-
 .../openai_transcription_client.py                 |  7 +-
 .../openai/test_transcription_validation.py        | 33 +++++++++
 vllm/entrypoints/openai/protocol.py                | 74 ++++++++++++++++---
 4 files changed, 122 insertions(+), 11 deletions(-)

diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index a62d4a79e2aa..34382c87a484 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -402,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai-python).
 To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
 :::
 
+Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
 
-Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
+#### Extra Parameters
+
+The following [sampling parameters](#sampling-params) are supported.
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-sampling-params
+:end-before: end-transcription-sampling-params
+:::
+
+The following extra parameters are supported:
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-extra-params
+:end-before: end-transcription-extra-params
+:::
 
 (tokenizer-api)=
 
 ## Tokenizer API
diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py
index 062868dd8adf..5fcb7c526416 100644
--- a/examples/online_serving/openai_transcription_client.py
+++ b/examples/online_serving/openai_transcription_client.py
@@ -26,7 +26,12 @@ def sync_openai():
             model="openai/whisper-large-v3",
             language="en",
             response_format="json",
-            temperature=0.0)
+            temperature=0.0,
+            # Additional sampling params not provided by OpenAI API.
+            extra_body=dict(
+                seed=4419,
+                repetition_penalty=1.3,
+            ))
         print("transcription result:", transcription.text)
diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 29571bcd7649..5c48df3cebbc 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -192,3 +192,36 @@ async def test_stream_options(winning_call):
         else:
             continuous = continuous and hasattr(chunk, 'usage')
     assert final and continuous
+
+
+@pytest.mark.asyncio
+async def test_sampling_params(mary_had_lamb):
+    """
+    Compare sampling with params and greedy sampling to assert results
+    are different when extreme sampling parameter values are picked.
+    """
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.8,
+            extra_body=dict(seed=42,
+                            repetition_penalty=1.9,
+                            top_k=12,
+                            top_p=0.4,
+                            min_p=0.5,
+                            frequency_penalty=1.8,
+                            presence_penalty=2.0))
+
+        greedy_transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.0,
+            extra_body=dict(seed=42))
+
+        assert greedy_transcription.text != transcription.text
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 4639b4cea06b..8d2ab29d221e 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1577,14 +1577,6 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     ## TODO (varun) : Support if set to 0, certain thresholds are met !!
-    temperature: float = Field(default=0.0)
-    """The sampling temperature, between 0 and 1.
-
-    Higher values like 0.8 will make the output more random, while lower values
-    like 0.2 will make it more focused / deterministic. If set to 0, the model
-    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
-    to automatically increase the temperature until certain thresholds are hit.
-    """
 
     timestamp_granularities: list[Literal["word", "segment"]] = Field(
         alias="timestamp_granularities[]", default=[])
@@ -1596,6 +1588,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    # doc: begin-transcription-extra-params
     stream: Optional[bool] = False
     """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
@@ -1604,10 +1597,51 @@ class TranscriptionRequest(OpenAIBaseModel):
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
     stream_continuous_usage_stats: Optional[bool] = False
+    # doc: end-transcription-extra-params
+
+    # doc: begin-transcription-sampling-params
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    top_p: Optional[float] = None
+    """Enables nucleus (top-p) sampling, where tokens are selected from the
+    smallest possible set whose cumulative probability exceeds `p`.
+    """
+
+    top_k: Optional[int] = None
+    """Limits sampling to the `k` most probable tokens at each step."""
+
+    min_p: Optional[float] = None
+    """Filters out tokens with a probability lower than `min_p`, ensuring a
+    minimum likelihood threshold during sampling.
+    """
+
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+    """The seed to use for sampling."""
+
+    frequency_penalty: Optional[float] = 0.0
+    """The frequency penalty to use for sampling."""
+
+    repetition_penalty: Optional[float] = None
+    """The repetition penalty to use for sampling."""
+
+    presence_penalty: Optional[float] = 0.0
+    """The presence penalty to use for sampling."""
+    # doc: end-transcription-sampling-params
 
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
-        "temperature": 0,
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
     }
 
     def to_sampling_params(
@@ -1619,13 +1653,35 @@ class TranscriptionRequest(OpenAIBaseModel):
         if default_sampling_params is None:
             default_sampling_params = {}
 
+        # Default parameters
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])
 
         return SamplingParams.from_optional(temperature=temperature,
                                             max_tokens=max_tokens,
+                                            seed=self.seed,
+                                            top_p=top_p,
+                                            top_k=top_k,
+                                            min_p=min_p,
+                                            frequency_penalty=self.frequency_penalty,
+                                            repetition_penalty=repetition_penalty,
+                                            presence_penalty=self.presence_penalty,
                                             output_kind=RequestOutputKind.DELTA
                                             if self.stream \
                                             else RequestOutputKind.FINAL_ONLY)
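
The fallback chain in `to_sampling_params` is the heart of this change: a value set on the request wins, then any server-level default (in vLLM, derived from the model's generation config), and finally the hard-coded `_DEFAULT_SAMPLING_PARAMS` table. A minimal standalone sketch of that resolution order, where `resolve` is a hypothetical helper for illustration and not part of the patch:

# Standalone sketch of the resolution order implemented above; `resolve`
# is a hypothetical helper, not the actual vLLM code.
from typing import Optional

# Mirrors the fallback table added to TranscriptionRequest by this patch.
_DEFAULT_SAMPLING_PARAMS = {
    "repetition_penalty": 1.0,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
}

def resolve(name: str, request_value: Optional[float],
            server_defaults: Optional[dict] = None) -> float:
    # A value set on the request wins; otherwise fall back to the
    # server-level default, and finally to the hard-coded table.
    if request_value is not None:
        return request_value
    return (server_defaults or {}).get(name, _DEFAULT_SAMPLING_PARAMS[name])

# extra_body=dict(top_k=12) pins top_k; everything else falls through.
assert resolve("top_k", 12) == 12
assert resolve("top_p", None) == 1.0
# A server default (in vLLM, taken from the model's generation config)
# beats the table but loses to an explicit request value.
assert resolve("top_p", None, {"top_p": 0.9}) == 0.9

Note that `temperature` keeps `Field(default=0.0)` on the request model, so greedy decoding remains the out-of-the-box behavior; the `1.0` entry in the table only takes effect when a client explicitly sends a null temperature.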