[Frontend] Gemma3n audio transcriptions/translations endpoint (#23735)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Nicolò Lucchesi 2025-09-01 12:07:46 +02:00 committed by GitHub
parent 107284959a
commit d46934b229
GPG Key ID: B5690EEEBB952194
9 changed files with 189 additions and 63 deletions
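
For orientation before the per-file diffs, here is a minimal client-side sketch of what this change enables: Gemma3n models served by vLLM can now be targeted through the OpenAI-compatible `/v1/audio/transcriptions` and `/v1/audio/translations` endpoints. The server address, API key, and audio file name below are placeholders; the request parameters mirror the tests in this commit.

```python
# Sketch only: assumes a vLLM server already running, e.g.
#   vllm serve google/gemma-3n-E2B-it --enforce-eager
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Transcription: speech -> text in the source language.
with open("sample_it.wav", "rb") as f:  # placeholder audio file
    transcription = client.audio.transcriptions.create(
        model="google/gemma-3n-E2B-it",
        file=f,
        language="it",
        response_format="text",
        temperature=0.0,
    )
print(transcription)

# Translation: speech -> text in another language. `to_language` is a vLLM
# extension, so it goes through `extra_body` rather than a named SDK argument.
with open("sample_it.wav", "rb") as f:
    translation = client.audio.translations.create(
        model="google/gemma-3n-E2B-it",
        file=f,
        response_format="text",
        extra_body=dict(language="it", to_language="en"),
        temperature=0.0,
    )
print(translation)
```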

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.audio import AudioAsset
@pytest.fixture
def mary_had_lamb():
path = AudioAsset('mary_had_lamb').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture
def winning_call():
path = AudioAsset('winning_call').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture
def foscolo():
# Test translation it->en
path = AudioAsset('azacinto_foscolo').get_local_path()
with open(str(path), "rb") as f:
yield f

View File

@@ -12,8 +12,6 @@ import pytest
import pytest_asyncio
import soundfile as sf
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/whisper-large-v3-turbo"
@@ -24,20 +22,6 @@ MISTRAL_FORMAT_ARGS = [
]
@pytest.fixture
def mary_had_lamb():
path = AudioAsset('mary_had_lamb').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture
def winning_call():
path = AudioAsset('winning_call').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
@@ -76,6 +60,25 @@ async def test_basic_audio(mary_had_lamb, model_name):
assert out_usage["seconds"] == 16, out_usage["seconds"]
@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
# Gemma accuracy on some of the audio samples we use is particularly bad,
# hence we use a different one here. WER is evaluated separately.
model_name = "google/gemma-3n-E2B-it"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=foscolo,
language="it",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert "da cui vergine nacque Venere" in out
@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
# text to text model

View File

@@ -12,32 +12,24 @@ import pytest
import pytest_asyncio
import soundfile as sf
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/whisper-small"
SERVER_ARGS = ["--enforce-eager"]
@pytest.fixture
def foscolo():
# Test translation it->en
path = AudioAsset('azacinto_foscolo').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
yield remote_server
@pytest.fixture(scope="module",
params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
def server(request):
# Parametrize over model name
with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
yield remote_server, request.param
@pytest_asyncio.fixture
async def client(server):
async def client_and_model(server):
server, model_name = server
async with server.get_async_client() as async_client:
yield async_client
yield async_client, model_name
@pytest.mark.asyncio
@@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo):
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client):
async def test_basic_audio(foscolo, client_and_model):
client, model_name = client_and_model
translation = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
response_format="text",
# TODO remove once language detection is implemented
extra_body=dict(language="it"),
# TODO remove `language="it"` once language detection is implemented
extra_body=dict(language="it", to_language="en"),
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert "greek sea" in out
@pytest.mark.asyncio
async def test_audio_prompt(foscolo, client):
async def test_audio_prompt(foscolo, client_and_model):
client, model_name = client_and_model
# Condition whisper on starting text
prompt = "Nor have I ever"
transcription = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
prompt=prompt,
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en"),
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
@@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client):
@pytest.mark.asyncio
async def test_streaming_response(foscolo, client, server):
async def test_streaming_response(foscolo, client_and_model, server):
client, model_name = client_and_model
translation = ""
res_no_stream = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
response_format="json",
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en", seed=42),
temperature=0.0)
# Stream via HTTPX since OpenAI translation client doesn't expose streaming
server, model_name = server
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"model": model_name,
"language": "it",
"to_language": "en",
"stream": True,
"temperature": 0.0,
"seed": 42,
}
foscolo.seek(0)
async with httpx.AsyncClient() as http_client:
@@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server):
text = chunk["choices"][0].get("delta", {}).get("content")
translation += text or ""
assert translation == res_no_stream.text
res_stream = translation.split()
# NOTE: There's a small source of non-determinism here, likely in the attn
# computation, which can cause a few tokens to differ while remaining
# semantically very close.
assert sum([
x == y for x, y in zip(res_stream, res_no_stream.text.split())
]) >= len(res_stream) * 0.9
@pytest.mark.asyncio
async def test_stream_options(foscolo, client, server):
async def test_stream_options(foscolo, server):
server, model_name = server
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"model": model_name,
"language": "it",
"to_language": "en",
"stream": True,
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
@@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server):
@pytest.mark.asyncio
async def test_long_audio_request(foscolo, client):
async def test_long_audio_request(foscolo, client_and_model):
client, model_name = client_and_model
if model_name == "google/gemma-3n-E2B-it":
pytest.skip("Gemma3n does not support long audio requests")
foscolo.seek(0)
audio, sr = librosa.load(foscolo)
repeated_audio = np.tile(audio, 2)
@@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client):
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
translation = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=buffer,
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en"),
response_format="text",
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()

View File

@@ -2175,6 +2175,13 @@ class TranscriptionRequest(OpenAIBaseModel):
)
# --8<-- [end:transcription-extra-params]
to_language: Optional[str] = None
"""The language of the output audio we transcribe to.
Please note that this is not currently used by supported models at this
time, but it is a placeholder for future use, matching translation api.
"""
# --8<-- [start:transcription-sampling-params]
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
@@ -2408,6 +2415,9 @@ class TranslationRequest(OpenAIBaseModel):
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
@@ -2427,6 +2437,14 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy.
"""
to_language: Optional[str] = None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
@@ -2458,6 +2476,7 @@ class TranslationRequest(OpenAIBaseModel):
return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)
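
To make the new `TranslationRequest` fields concrete, here is a hedged sketch of a raw streaming request exercising `to_language`, `seed`, and `stream`; the server URL and audio file are placeholders, and the payload shape mirrors the streaming test above.

```python
# Sketch: raw multipart request against the translations endpoint with the
# new fields. Server address, API key and audio file are placeholders.
import httpx

url = "http://localhost:8000/v1/audio/translations"
headers = {"Authorization": "Bearer EMPTY"}
data = {
    "model": "google/gemma-3n-E2B-it",
    "language": "it",      # language of the input audio
    "to_language": "en",   # target language (new field)
    "seed": 42,            # sampling seed (new field)
    "temperature": 0.0,
    "stream": True,        # vLLM extension: stream deltas like Chat Completions
}

with open("sample_it.wav", "rb") as f:
    with httpx.stream("POST", url, headers=headers, data=data,
                      files={"file": f}, timeout=None) as resp:
        for line in resp.iter_lines():
            # Server-sent events: each chunk line is prefixed with "data: ".
            if line.startswith("data: ") and "[DONE]" not in line:
                print(line.removeprefix("data: "))
```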

View File

@@ -89,6 +89,9 @@ class OpenAISpeechToText(OpenAIServing):
) -> tuple[list[PromptType], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Only validate to_language when it is provided, so Whisper requests
# that omit it don't trigger extra logging.
to_language = self.model_cls.validate_language(request.to_language) \
if request.to_language else None
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
@@ -112,7 +115,9 @@ class OpenAISpeechToText(OpenAIServing):
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt)
request_prompt=request.prompt,
to_language=to_language,
)
prompts.append(prompt)
return prompts, duration

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, TypedDict, Union, cast
from typing import Any, Literal, Optional, TypedDict, Union, cast
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
@@ -13,7 +14,8 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
Gemma3nVisionConfig)
from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import VllmConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
@@ -21,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -694,3 +701,53 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
return "<audio_soft_token>"
else:
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
We fix its prompt here in order to standardize transcription/translation
requests.
"""
# Transcribe this audio [into <>] | for transcription
# Translate this audio [from <> into <>] | for translation
prompt = "<start_of_turn>user\n"
prompt += "Transcribe" if task_type == "transcribe" else "Translate"
prompt += " this audio"
# We assume the language is a valid ISO 639-1 code.
full_lang_name = cls.supported_languages.get(language, "")
# Translation only for now
full_lang_name_to = cls.supported_languages.get(to_language, "")
if task_type == "transcribe" and full_lang_name:
prompt += f" into {full_lang_name}"
elif task_type == "translate":
if full_lang_name:
prompt += f" from {full_lang_name}"
if full_lang_name_to:
prompt += f" into {full_lang_name_to}"
prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
audio = (audio, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
return cast(PromptType, prompts_dict)
@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
return SpeechToTextConfig(
# Let's set this to 30 as suggested in the docs for now, although
# the model is only limited by its context length.
max_audio_clip_s=30,
sample_rate=16000,
# TODO enable chunking after more thorough testing.
min_energy_split_window_size=None,
)
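
To illustrate the resulting prompt format, here is a standalone sketch that reproduces just the string template built by `get_generation_prompt` above; the two-entry language map is an abbreviation of `ISO639_1_SUPPORTED_LANGS`.

```python
# Standalone illustration of the Gemma3n prompt template; only two languages
# are mapped here for brevity.
from typing import Literal, Optional

SUPPORTED_LANGS = {"it": "Italian", "en": "English"}  # abbreviated mapping


def build_prompt(task_type: Literal["transcribe", "translate"],
                 language: Optional[str],
                 to_language: Optional[str]) -> str:
    prompt = "<start_of_turn>user\n"
    prompt += "Transcribe" if task_type == "transcribe" else "Translate"
    prompt += " this audio"
    lang = SUPPORTED_LANGS.get(language, "")
    to_lang = SUPPORTED_LANGS.get(to_language, "")
    if task_type == "transcribe" and lang:
        prompt += f" into {lang}"
    elif task_type == "translate":
        if lang:
            prompt += f" from {lang}"
        if to_lang:
            prompt += f" into {to_lang}"
    return prompt + ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"


# "<start_of_turn>user\nTranscribe this audio into Italian: <audio_soft_token>..."
print(build_prompt("transcribe", "it", None))
# "<start_of_turn>user\nTranslate this audio from Italian into English: ..."
print(build_prompt("translate", "it", "en"))
```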

View File

@@ -700,8 +700,10 @@ class SupportsTranscription(Protocol):
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
returns a valid PromptType."""
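
For third-party models implementing `SupportsTranscription`, a hedged sketch of the updated hook: the class name and prompt format are placeholders, and other protocol members (e.g. `get_speech_to_text_config`) are omitted.

```python
# Sketch: a placeholder model implementing the updated classmethod signature.
from typing import Literal, Optional, cast

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class MyASRModel:  # would also inherit nn.Module and SupportsTranscription

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        # Build whatever prompt the model expects; this one is a placeholder.
        prompt = f"{task_type} ({language} -> {to_language}): <audio>"
        return cast(
            PromptType,
            {
                "prompt": prompt,
                "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
            },
        )
```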

View File

@@ -5,7 +5,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from math import ceil
from typing import Optional, Union, cast
from typing import Literal, Optional, Union, cast
import numpy as np
import regex as re
@@ -455,8 +455,10 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_generation_prompt(cls, audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate),
format="wav") # lossless

View File

@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Optional, TypedDict, Union, cast
from typing import Literal, Optional, TypedDict, Union, cast
import numpy as np
import torch
@@ -783,8 +783,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: Optional[str],
task_type: str,
request_prompt: str) -> PromptType:
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")