[Frontend] Gemma3n audio transcriptions/translations endpoint (#23735)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent 107284959a
commit d46934b229
tests/entrypoints/openai/conftest.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.assets.audio import AudioAsset
+
+
+@pytest.fixture
+def mary_had_lamb():
+    path = AudioAsset('mary_had_lamb').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
+
+
+@pytest.fixture
+def winning_call():
+    path = AudioAsset('winning_call').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
+
+
+@pytest.fixture
+def foscolo():
+    # Test translation it->en
+    path = AudioAsset('azacinto_foscolo').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
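
These shared fixtures yield already-opened binary file handles, so each test can pass the audio straight to the OpenAI client's `file=` parameter. A minimal standalone sketch of the same pattern outside pytest, assuming vLLM is installed locally (the asset name is one of those used above):

    from vllm.assets.audio import AudioAsset

    # Resolve the bundled sample clip to a local file path.
    path = AudioAsset('mary_had_lamb').get_local_path()

    # Tests receive this file already opened in binary mode:
    with open(str(path), "rb") as f:
        audio_bytes = f.read()
        print(f"{path}: {len(audio_bytes)} bytes")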
@@ -12,8 +12,6 @@ import pytest
 import pytest_asyncio
 import soundfile as sf
 
-from vllm.assets.audio import AudioAsset
-
 from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "openai/whisper-large-v3-turbo"
@@ -24,20 +22,6 @@ MISTRAL_FORMAT_ARGS = [
 ]
 
 
-@pytest.fixture
-def mary_had_lamb():
-    path = AudioAsset('mary_had_lamb').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
-@pytest.fixture
-def winning_call():
-    path = AudioAsset('winning_call').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
 @pytest.fixture(scope="module")
 def server():
     with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
@@ -76,6 +60,25 @@ async def test_basic_audio(mary_had_lamb, model_name):
     assert out_usage["seconds"] == 16, out_usage["seconds"]
 
 
+@pytest.mark.asyncio
+async def test_basic_audio_gemma(foscolo):
+    # Gemma accuracy on some of the audio samples we use is particularly bad,
+    # hence we use a different one here. WER is evaluated separately.
+    model_name = "google/gemma-3n-E2B-it"
+    server_args = ["--enforce-eager"]
+
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=foscolo,
+            language="it",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert "da cui vergine nacque Venere" in out
+
+
 @pytest.mark.asyncio
 async def test_non_asr_model(winning_call):
     # text to text model
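
In user terms, the new test exercises this flow: serve a Gemma3n checkpoint, then call the standard OpenAI-compatible transcriptions endpoint. A hedged sketch with the official `openai` client, assuming a local server started with `vllm serve google/gemma-3n-E2B-it --enforce-eager` and an Italian-language clip at an illustrative path:

    import json

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    with open("sample_it.wav", "rb") as f:  # illustrative path
        transcription = client.audio.transcriptions.create(
            model="google/gemma-3n-E2B-it",
            file=f,
            language="it",
            response_format="text",
            temperature=0.0)

    # As in the test above, the response body is JSON with a 'text' field.
    print(json.loads(transcription)['text'])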
@@ -12,32 +12,24 @@ import pytest
 import pytest_asyncio
 import soundfile as sf
 
-from vllm.assets.audio import AudioAsset
-
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "openai/whisper-small"
 SERVER_ARGS = ["--enforce-eager"]
 
 
-@pytest.fixture
-def foscolo():
-    # Test translation it->en
-    path = AudioAsset('azacinto_foscolo').get_local_path()
-    with open(str(path), "rb") as f:
-        yield f
-
-
-@pytest.fixture(scope="module")
-def server():
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
-        yield remote_server
+@pytest.fixture(scope="module",
+                params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
+def server(request):
+    # Parametrize over model name
+    with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
+        yield remote_server, request.param
 
 
 @pytest_asyncio.fixture
-async def client(server):
+async def client_and_model(server):
+    server, model_name = server
     async with server.get_async_client() as async_client:
-        yield async_client
+        yield async_client, model_name
 
 
 @pytest.mark.asyncio
@@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo):
 
 # NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
 @pytest.mark.asyncio
-async def test_basic_audio(foscolo, client):
+async def test_basic_audio(foscolo, client_and_model):
+    client, model_name = client_and_model
     translation = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=foscolo,
         response_format="text",
-        # TODO remove once language detection is implemented
-        extra_body=dict(language="it"),
+        # TODO remove `language="it"` once language detection is implemented
+        extra_body=dict(language="it", to_language="en"),
         temperature=0.0)
     out = json.loads(translation)['text'].strip().lower()
     assert "greek sea" in out
 
 
 @pytest.mark.asyncio
-async def test_audio_prompt(foscolo, client):
+async def test_audio_prompt(foscolo, client_and_model):
+    client, model_name = client_and_model
     # Condition whisper on starting text
     prompt = "Nor have I ever"
     transcription = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=foscolo,
         prompt=prompt,
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en"),
         response_format="text",
         temperature=0.0)
     out = json.loads(transcription)['text']
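
The translation tests show where the new `to_language` field rides: it is not part of the upstream OpenAI SDK, so it goes through `extra_body`. A standalone sketch of the same request against a local server (model name and file path are illustrative):

    import json

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    with open("sample_it.wav", "rb") as f:
        translation = client.audio.translations.create(
            model="openai/whisper-small",
            file=f,
            response_format="text",
            # vLLM-specific fields travel in extra_body.
            extra_body=dict(language="it", to_language="en"),
            temperature=0.0)

    print(json.loads(translation)['text'])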
@@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client):
 
 
 @pytest.mark.asyncio
-async def test_streaming_response(foscolo, client, server):
+async def test_streaming_response(foscolo, client_and_model, server):
+    client, model_name = client_and_model
     translation = ""
     res_no_stream = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=foscolo,
         response_format="json",
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en", seed=42),
         temperature=0.0)
 
     # Stream via HTTPX since OpenAI translation client doesn't expose streaming
+    server, model_name = server
     url = server.url_for("v1/audio/translations")
     headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
     data = {
-        "model": MODEL_NAME,
+        "model": model_name,
         "language": "it",
+        "to_language": "en",
         "stream": True,
         "temperature": 0.0,
+        "seed": 42,
     }
     foscolo.seek(0)
     async with httpx.AsyncClient() as http_client:
@@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server):
             text = chunk["choices"][0].get("delta", {}).get("content")
             translation += text or ""
 
-    assert translation == res_no_stream.text
+    res_stream = translation.split()
+    # NOTE There's a small non-deterministic issue here, likely in the attn
+    # computation, which will cause a few tokens to be different, while still
+    # being very close semantically.
+    assert sum([
+        x == y for x, y in zip(res_stream, res_no_stream.text.split())
+    ]) >= len(res_stream) * 0.9
 
 
 @pytest.mark.asyncio
-async def test_stream_options(foscolo, client, server):
+async def test_stream_options(foscolo, server):
+    server, model_name = server
     url = server.url_for("v1/audio/translations")
     headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
     data = {
-        "model": MODEL_NAME,
+        "model": model_name,
         "language": "it",
+        "to_language": "en",
         "stream": True,
         "stream_include_usage": True,
         "stream_continuous_usage_stats": True,
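
Since the OpenAI translations client has no streaming support, the tests POST to the endpoint directly. A sketch of consuming the resulting server-sent-event stream with httpx, following the same `choices[0].delta.content` chunk shape the tests parse (URL and file path are assumptions for the example):

    import json

    import httpx

    URL = "http://localhost:8000/v1/audio/translations"
    DATA = {
        "model": "openai/whisper-small",
        "language": "it",
        "to_language": "en",
        "stream": True,
        "temperature": 0.0,
    }

    async def stream_translation() -> str:
        translation = ""
        async with httpx.AsyncClient() as http_client:
            with open("sample_it.wav", "rb") as f:
                async with http_client.stream("POST", URL,
                                              files={"file": f},
                                              data=DATA) as res:
                    async for line in res.aiter_lines():
                        # Frames look like 'data: {...}', then 'data: [DONE]'.
                        if not line.startswith("data: "):
                            continue
                        payload = line.removeprefix("data: ")
                        if payload == "[DONE]":
                            break
                        chunk = json.loads(payload)
                        delta = chunk["choices"][0].get("delta", {})
                        translation += delta.get("content") or ""
        return translation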
@@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server):
 
 
 @pytest.mark.asyncio
-async def test_long_audio_request(foscolo, client):
+async def test_long_audio_request(foscolo, client_and_model):
+    client, model_name = client_and_model
+    if model_name == "google/gemma-3n-E2B-it":
+        pytest.skip("Gemma3n does not support long audio requests")
     foscolo.seek(0)
     audio, sr = librosa.load(foscolo)
     repeated_audio = np.tile(audio, 2)
@@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client):
     sf.write(buffer, repeated_audio, sr, format='WAV')
     buffer.seek(0)
     translation = await client.audio.translations.create(
-        model=MODEL_NAME,
+        model=model_name,
         file=buffer,
-        extra_body=dict(language="it"),
+        extra_body=dict(language="it", to_language="en"),
         response_format="text",
         temperature=0.0)
     out = json.loads(translation)['text'].strip().lower()
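
To exercise the long-audio path, the test stretches the clip past the model's 30-second window by tiling the waveform in memory. The same trick in isolation (input path illustrative):

    import io

    import librosa
    import numpy as np
    import soundfile as sf

    audio, sr = librosa.load("sample_it.wav")
    # Double the duration by repeating the waveform end to end.
    repeated_audio = np.tile(audio, 2)

    # Serialize to an in-memory WAV file, ready to upload as `file=`.
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format='WAV')
    buffer.seek(0)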
@@ -2175,6 +2175,13 @@ class TranscriptionRequest(OpenAIBaseModel):
     )
     # --8<-- [end:transcription-extra-params]
 
+    to_language: Optional[str] = None
+    """The language of the output audio we transcribe to.
+
+    Please note that this is not currently used by supported models at this
+    time, but it is a placeholder for future use, matching translation api.
+    """
+
     # --8<-- [start:transcription-sampling-params]
     temperature: float = Field(default=0.0)
     """The sampling temperature, between 0 and 1.
@@ -2408,6 +2415,9 @@ class TranslationRequest(OpenAIBaseModel):
 
     # TODO support additional sampling parameters
     # --8<-- [start:translation-sampling-params]
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+    """The seed to use for sampling."""
+
     temperature: float = Field(default=0.0)
     """The sampling temperature, between 0 and 1.
 
@@ -2427,6 +2437,14 @@ class TranslationRequest(OpenAIBaseModel):
     will improve accuracy.
     """
 
+    to_language: Optional[str] = None
+    """The language of the input audio we translate to.
+
+    Please note that this is not supported by all models, refer to the
+    specific model documentation for more details.
+    For instance, Whisper only supports `to_language=en`.
+    """
+
     stream: Optional[bool] = False
     """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
@@ -2458,6 +2476,7 @@ class TranslationRequest(OpenAIBaseModel):
 
         return SamplingParams.from_optional(temperature=temperature,
                                             max_tokens=max_tokens,
+                                            seed=self.seed,
                                             output_kind=RequestOutputKind.DELTA
                                             if self.stream \
                                             else RequestOutputKind.FINAL_ONLY)
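
With the new `seed` field, the mapping from a streaming translation request to sampling parameters is direct. A sketch with illustrative values (`max_tokens` is derived from the model context in the real code; the constant here is only for the example):

    from vllm.sampling_params import RequestOutputKind, SamplingParams

    params = SamplingParams.from_optional(
        temperature=0.0,
        max_tokens=440,  # illustrative; computed from model config in vLLM
        seed=42,
        # DELTA streams incremental outputs; FINAL_ONLY when stream is unset.
        output_kind=RequestOutputKind.DELTA)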
@@ -89,6 +89,9 @@ class OpenAISpeechToText(OpenAIServing):
     ) -> tuple[list[PromptType], float]:
         # Validate request
         language = self.model_cls.validate_language(request.language)
+        # Skip to_language validation to avoid extra logging for Whisper.
+        to_language = self.model_cls.validate_language(request.to_language) \
+            if request.to_language else None
 
         if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
             raise ValueError("Maximum file size exceeded.")
@@ -112,7 +115,9 @@ class OpenAISpeechToText(OpenAIServing):
                 model_config=self.model_config,
                 language=language,
                 task_type=self.task_type,
-                request_prompt=request.prompt)
+                request_prompt=request.prompt,
+                to_language=to_language,
+            )
             prompts.append(prompt)
         return prompts, duration
 
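
In short, the serving layer validates both language fields once per request, then threads `to_language` through to whichever model class builds the prompt. A simplified paraphrase of the flow above (not the full implementation; surrounding variables are assumed in scope as in the real method):

    # Paraphrased control flow, assuming model_cls, request, audio_chunks,
    # stt_config, model_config and task_type are in scope.
    language = model_cls.validate_language(request.language)
    to_language = (model_cls.validate_language(request.to_language)
                   if request.to_language else None)

    prompts = []
    for chunk in audio_chunks:
        prompts.append(
            model_cls.get_generation_prompt(
                audio=chunk,
                stt_config=stt_config,
                model_config=model_config,
                language=language,
                task_type=task_type,  # "transcribe" or "translate"
                request_prompt=request.prompt,
                to_language=to_language))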
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union, cast
+from typing import Any, Literal, Optional, TypedDict, Union, cast
 
 import numpy as np
 import torch
 from torch import nn
 from transformers import AutoModel, BatchFeature
@@ -13,7 +14,8 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                          Gemma3nVisionConfig)
 from transformers.models.siglip import SiglipImageProcessorFast
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -21,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
                                         info=Gemma3nProcessingInfo,
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                      SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -694,3 +701,53 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
             return "<audio_soft_token>"
         else:
             raise ValueError(f"Unsupported modality: {modality}")
+
+    @classmethod
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig,
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
+        """
+        Gemma3n supports "free-form" transcription.
+        We fix its prompt here to standardize transcriptions/translations
+        requests.
+        """
+        # Transcribe this audio [into <>] | for transcription
+        # Translate this audio [from <> into <>] | for translation
+        prompt = "<start_of_turn>user\n"
+        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
+        prompt += " this audio"
+
+        # We assume the language is a valid ISO 639-1 code.
+        full_lang_name = cls.supported_languages.get(language, "")
+        # Translation only for now
+        full_lang_name_to = cls.supported_languages.get(to_language, "")
+
+        if task_type == "transcribe" and full_lang_name:
+            prompt += f" into {full_lang_name}"
+        elif task_type == "translate":
+            if full_lang_name:
+                prompt += f" from {full_lang_name}"
+            if full_lang_name_to:
+                prompt += f" into {full_lang_name_to}"
+
+        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
+
+        audio = (audio, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
+        return cast(PromptType, prompts_dict)
+
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        return SpeechToTextConfig(
+            # Let's set this to 30 as suggested in the docs for now, although
+            # the model is only limited by its context length.
+            max_audio_clip_s=30,
+            sample_rate=16000,
+            # TODO enable chunking after more thorough testing.
+            min_energy_split_window_size=None,
+        )
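
To make the template concrete: a translation request with `language="it"` and `to_language="en"` resolves, via the ISO 639-1 table shared with Whisper, to exactly this prompt string:

    # task_type="translate", language="it", to_language="en"
    prompt = ("<start_of_turn>user\n"
              "Translate this audio from Italian into English: "
              "<audio_soft_token><end_of_turn>\n"
              "<start_of_turn>model\n")

A transcription request with `language="it"` would instead produce "Transcribe this audio into Italian: ...".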
@@ -700,8 +700,10 @@ class SupportsTranscription(Protocol):
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig,
                               model_config: ModelConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         """Get the prompt for the ASR model.
         The model has control over the construction, as long as it
         returns a valid PromptType."""
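
The widened `SupportsTranscription` hook is what a model implements to opt into these endpoints. A minimal hypothetical skeleton (the class and its prompt template are illustrative, not a real vLLM model):

    from typing import Literal, Optional, cast

    import numpy as np

    from vllm.config import ModelConfig, SpeechToTextConfig
    from vllm.inputs.data import PromptType


    class MyASRModel:  # would also inherit nn.Module etc. in practice
        supported_languages = {"en": "English", "it": "Italian"}

        @classmethod
        def get_generation_prompt(cls, audio: np.ndarray,
                                  stt_config: SpeechToTextConfig,
                                  model_config: ModelConfig,
                                  language: Optional[str],
                                  task_type: Literal["transcribe",
                                                     "translate"],
                                  request_prompt: str,
                                  to_language: Optional[str]) -> PromptType:
            # Model-specific template; only the return shape is prescribed.
            prompt = f"{task_type} this audio: <audio>"
            return cast(
                PromptType, {
                    "prompt": prompt,
                    "multi_modal_data": {
                        "audio": (audio, stt_config.sample_rate),
                    },
                })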
@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from math import ceil
-from typing import Optional, Union, cast
+from typing import Literal, Optional, Union, cast
 
 import numpy as np
 import regex as re
@@ -455,8 +455,10 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_generation_prompt(cls, audio: np.ndarray,
                               model_config: ModelConfig,
                               stt_config: SpeechToTextConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio = Audio(audio, int(stt_config.sample_rate),
                       format="wav")  # lossless
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Optional, TypedDict, Union, cast
+from typing import Literal, Optional, TypedDict, Union, cast
 
 import numpy as np
 import torch
@@ -783,8 +783,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                               model_config: ModelConfig,  # not needed here
                               stt_config: SpeechToTextConfig,
                               language: Optional[str],
-                              task_type: str,
-                              request_prompt: str) -> PromptType:
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         if language is None:
             raise ValueError(
                 "Language must be specified when creating the Whisper prompt")