[Frontend] Add sampling params to v1/audio/transcriptions endpoint (#16591)
Signed-off-by: Jannis Schönleber <joennlae@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Jannis Schönleber <joennlae@gmail.com>

parent 1d4680fad2
commit 2ef0dc53b8
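In practice, this change lets clients pass vLLM sampling parameters that the OpenAI Transcriptions API does not expose (seed, repetition_penalty, top_p, top_k, min_p, frequency_penalty, presence_penalty) through the client's `extra_body`, as the updated example and tests below do. A minimal usage sketch, assuming a vLLM server already running at http://localhost:8000/v1 serving openai/whisper-large-v3 and a local audio.wav file:

# Minimal usage sketch (assumptions: server at http://localhost:8000/v1
# serving openai/whisper-large-v3, local audio.wav file).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("audio.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=f,
        language="en",
        response_format="json",
        temperature=0.0,
        # vLLM-specific sampling params go through extra_body.
        extra_body=dict(seed=4419, repetition_penalty=1.3),
    )
print("transcription result:", transcription.text)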
Documentation (Transcriptions API section of the OpenAI-compatible server docs):

@@ -402,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai-python)
 To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
 :::
 
+Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
+
 <!-- TODO: api enforced limits + uploading audios -->
 
-Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
+#### Extra Parameters
+
+The following [sampling parameters](#sampling-params) are supported.
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-sampling-params
+:end-before: end-transcription-sampling-params
+:::
+
+The following extra parameters are supported:
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-extra-params
+:end-before: end-transcription-extra-params
+:::
 
 (tokenizer-api)=
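Because the Transcriptions API accepts multipart form data (the protocol below flattens the stream options "to simplify form data"), the same request can presumably also be made without the OpenAI client by sending the new fields as plain form fields. A rough sketch with `requests`; the field names come from this diff, but accepting them as flattened form fields is an assumption, not something this commit shows:

# Rough sketch, not from this commit: multipart form request without the
# OpenAI client. Assumes a vLLM server on localhost:8000 and that the new
# sampling fields are accepted as flattened form fields.
import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "language": "en",
            "response_format": "json",
            "temperature": 0.0,
            "seed": 4419,
            "repetition_penalty": 1.3,
        },
    )
print(resp.json())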
examples/online_serving/openai_transcription_client.py:

@@ -26,7 +26,12 @@ def sync_openai():
         model="openai/whisper-large-v3",
         language="en",
         response_format="json",
-        temperature=0.0)
+        temperature=0.0,
+        # Additional sampling params not provided by OpenAI API.
+        extra_body=dict(
+            seed=4419,
+            repetition_penalty=1.3,
+        ))
     print("transcription result:", transcription.text)
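The test below drives the endpoint through the async OpenAI client; for reference, a hedged sketch of the async form of the example above (same server assumptions):

# Hedged sketch: async variant of the example above (same assumptions).
import asyncio
from openai import AsyncOpenAI

async def main():
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    with open("audio.wav", "rb") as f:
        transcription = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3",
            file=f,
            language="en",
            temperature=0.8,
            extra_body=dict(seed=42, top_k=12, top_p=0.4),
        )
    print(transcription.text)

asyncio.run(main())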
Transcriptions API tests (async OpenAI client against a RemoteOpenAIServer):

@@ -192,3 +192,36 @@ async def test_stream_options(winning_call):
         else:
             continuous = continuous and hasattr(chunk, 'usage')
     assert final and continuous
+
+
+@pytest.mark.asyncio
+async def test_sampling_params(mary_had_lamb):
+    """
+    Compare sampling with params and greedy sampling to assert results
+    are different when extreme sampling parameter values are picked.
+    """
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.8,
+            extra_body=dict(seed=42,
+                            repetition_penalty=1.9,
+                            top_k=12,
+                            top_p=0.4,
+                            min_p=0.5,
+                            frequency_penalty=1.8,
+                            presence_penalty=2.0))
+
+        greedy_transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.0,
+            extra_body=dict(seed=42))
+
+        assert greedy_transcription.text != transcription.text
vllm/entrypoints/openai/protocol.py (class TranscriptionRequest):

@@ -1577,14 +1577,6 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     ## TODO (varun) : Support if set to 0, certain thresholds are met !!
-    temperature: float = Field(default=0.0)
-    """The sampling temperature, between 0 and 1.
-
-    Higher values like 0.8 will make the output more random, while lower values
-    like 0.2 will make it more focused / deterministic. If set to 0, the model
-    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
-    to automatically increase the temperature until certain thresholds are hit.
-    """
 
     timestamp_granularities: list[Literal["word", "segment"]] = Field(
         alias="timestamp_granularities[]", default=[])
@@ -1596,6 +1588,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    # doc: begin-transcription-extra-params
     stream: Optional[bool] = False
     """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
@@ -1604,10 +1597,51 @@ class TranscriptionRequest(OpenAIBaseModel):
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
     stream_continuous_usage_stats: Optional[bool] = False
+    # doc: end-transcription-extra-params
+
+    # doc: begin-transcription-sampling-params
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    top_p: Optional[float] = None
+    """Enables nucleus (top-p) sampling, where tokens are selected from the
+    smallest possible set whose cumulative probability exceeds `p`.
+    """
+
+    top_k: Optional[int] = None
+    """Limits sampling to the `k` most probable tokens at each step."""
+
+    min_p: Optional[float] = None
+    """Filters out tokens with a probability lower than `min_p`, ensuring a
+    minimum likelihood threshold during sampling.
+    """
+
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+    """The seed to use for sampling."""
+
+    frequency_penalty: Optional[float] = 0.0
+    """The frequency penalty to use for sampling."""
+
+    repetition_penalty: Optional[float] = None
+    """The repetition penalty to use for sampling."""
+
+    presence_penalty: Optional[float] = 0.0
+    """The presence penalty to use for sampling."""
+    # doc: end-transcription-sampling-params
 
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
-        "temperature": 0,
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
     }
 
     def to_sampling_params(
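The fields above are resolved in `to_sampling_params` below with a three-level fallback: a value set on the request wins, otherwise the server-supplied `default_sampling_params` dict, otherwise the class-level `_DEFAULT_SAMPLING_PARAMS`. A small self-contained sketch of that pattern (the names here are illustrative, not vLLM API):

# Illustrative sketch of the fallback order: request value ->
# server default_sampling_params -> class-level defaults.
_CLASS_DEFAULTS = {"repetition_penalty": 1.0, "temperature": 1.0,
                   "top_p": 1.0, "top_k": -1, "min_p": 0.0}

def resolve(name: str, request_value, server_defaults: dict):
    # Only None falls through; an explicit 0 or 0.0 on the request is kept.
    if request_value is not None:
        return request_value
    return server_defaults.get(name, _CLASS_DEFAULTS[name])

# Request sets top_k only; the server overrides min_p; the rest use class defaults.
print(resolve("top_k", 12, {"min_p": 0.1}))    # 12   (request value)
print(resolve("min_p", None, {"min_p": 0.1}))  # 0.1  (server default)
print(resolve("top_p", None, {"min_p": 0.1}))  # 1.0  (class default)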
@@ -1619,13 +1653,35 @@ class TranscriptionRequest(OpenAIBaseModel):
 
         if default_sampling_params is None:
             default_sampling_params = {}
 
         # Default parameters
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])
+
         return SamplingParams.from_optional(temperature=temperature,
                                             max_tokens=max_tokens,
+                                            seed=self.seed,
+                                            top_p=top_p,
+                                            top_k=top_k,
+                                            min_p=min_p,
+                                            frequency_penalty=self.frequency_penalty,
+                                            repetition_penalty=repetition_penalty,
+                                            presence_penalty=self.presence_penalty,
                                             output_kind=RequestOutputKind.DELTA
                                             if self.stream \
                                             else RequestOutputKind.FINAL_ONLY)
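For the synchronous docs example above (temperature=0.0, seed=4419, repetition_penalty=1.3), the method should therefore build roughly the following SamplingParams. A hedged sketch; the max_tokens value is a stand-in for whatever default the serving layer passes in:

# Hedged sketch of the SamplingParams produced for the sync example above.
# max_tokens is a stand-in for the serving layer's default, not a real value.
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

params = SamplingParams.from_optional(
    temperature=0.0,          # from the request
    max_tokens=200,           # stand-in default
    seed=4419,                # from extra_body
    top_p=1.0,                # class default (not set on the request)
    top_k=-1,                 # class default
    min_p=0.0,                # class default
    frequency_penalty=0.0,    # field default
    repetition_penalty=1.3,   # from extra_body
    presence_penalty=0.0,     # field default
    output_kind=RequestOutputKind.FINAL_ONLY,  # stream not requested
)
print(params)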