[Gemma3n] Fix audio batching (#24052)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi, 2025-09-02 16:23:35 +02:00, committed by GitHub
parent 8bd5844989
commit 0a74e9d0f2
2 changed files with 63 additions and 7 deletions


@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
     print("Chat completion output from base64 encoded audio:", result)


+def run_multi_audio(model: str) -> None:
+    from vllm.assets.audio import AudioAsset
+
+    # Two different audios to showcase batched inference.
+    audio_url = AudioAsset("winning_call").url
+    audio_base64 = encode_base64_content_from_url(audio_url)
+    audio_url2 = AudioAsset("azacinto_foscolo").url
+    audio_base64_2 = encode_base64_content_from_url(audio_url2)
+
+    # OpenAI-compatible schema (`input_audio`)
+    chat_completion_from_base64 = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Are these two audios the same?"},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64,
+                            "format": "wav",
+                        },
+                    },
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64_2,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        ],
+        model=model,
+        max_completion_tokens=64,
+    )
+
+    result = chat_completion_from_base64.choices[0].message.content
+    print("Chat completion output from input audio:", result)
+
+
 example_function_map = {
     "text-only": run_text_only,
     "single-image": run_single_image,
     "multi-image": run_multi_image,
+    "multi-audio": run_multi_audio,
     "video": run_video,
     "audio": run_audio,
 }
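The new example relies on two names defined earlier in the script but outside this hunk: the `client` object and the `encode_base64_content_from_url` helper. A minimal sketch of what those look like, assuming an OpenAI-compatible vLLM server on the default local port (the exact definitions live elsewhere in the example file):

```python
import base64

import requests
from openai import OpenAI

# Assumed setup from elsewhere in the example script: an OpenAI-compatible
# client pointed at a locally running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


def encode_base64_content_from_url(content_url: str) -> str:
    """Fetch remote content and return it base64-encoded (a sketch of the
    helper the example calls; the real definition is in the same script)."""
    response = requests.get(content_url)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")
```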


@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast

 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):

 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]):
             mm_kwargs,
             tok_kwargs,
         )
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
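The key move in this hunk is keeping both representations: the padded batch (for a single batched `audio_tower` call later) and the unpadded per-item features (so the processor cache stays independent of batch-mates). A toy illustration of the `f[mask]` unpadding expression, in plain torch rather than the vLLM code:

```python
import torch

# Toy of the unpadding above: boolean masking trims each padded row back to
# its true length, so each item's cached features no longer depend on
# whatever it happened to be batched with.
features = torch.randn(2, 5, 4)   # (num_audio, padded_seq_len, num_features)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])
unpadded = [f[m] for f, m in zip(features, mask)]
assert [u.shape[0] for u in unpadded] == [3, 5]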
@@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]):
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))

     def _get_prompt_updates(
         self,
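Because `input_features_padded` is a new key in the processor output, it needs its own field config entry so vLLM knows how to attribute it to audio items. A rough, illustrative sketch of the slicing contract that `batched("audio")` implies (an assumption about shapes, not vLLM internals):

```python
import torch

# Illustrative: for a "batched" field, the leading dimension indexes items of
# the modality, so per-item slices can be cached and later re-collated.
input_features_padded = torch.randn(2, 1, 100, 128)  # two audio items
per_item = [t for t in input_features_padded]        # slice along dim 0
assert len(per_item) == 2 and per_item[0].shape == (1, 100, 128)
```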
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if input_features_mask is None:
             return None

+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
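The pop-and-validate pattern above returns `None` as soon as any expected key is missing, so a request with no audio short-circuits cleanly. A self-contained sketch of the same pattern (a hypothetical helper, not the model code):

```python
from typing import Optional

import torch


def parse_audio_kwargs(**kwargs: object) -> Optional[dict]:
    """Hypothetical helper mirroring the pattern above: all three keys must
    be present, otherwise there is no audio input for this step."""
    keys = ("input_features", "input_features_mask", "input_features_padded")
    out = {k: kwargs.pop(k, None) for k in keys}
    if any(v is None for v in out.values()):
        return None
    return out


t = torch.zeros(1, 1, 4)
assert parse_audio_kwargs() is None
assert parse_audio_kwargs(input_features=t, input_features_mask=t,
                          input_features_padded=t) is not None
```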
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
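Putting it together: the tower now sees one padded batch plus an inverted validity mask, and per-item outputs can be recovered afterwards. A toy stand-in showing the shape mechanics, with a plain `nn.Linear` in place of the real Gemma3n audio tower:

```python
import torch
from torch import nn

# Illustrative stand-in, not the real audio tower: one batched forward pass
# over padded inputs, then mask out padding and trim per item.
tower = nn.Linear(128, 64)             # stand-in for self.audio_tower
padded = torch.randn(2, 100, 128)      # (num_audio, seq_len, num_features)
mask = torch.zeros(2, 100, dtype=torch.bool)
mask[0, :60] = True                    # item 0 has 60 valid frames
mask[1, :100] = True                   # item 1 has 100 valid frames

outputs = tower(padded)                # single batched call for all items
outputs = outputs.masked_fill(~mask.unsqueeze(-1), 0.0)
per_item = [out[m] for out, m in zip(outputs, mask)]
assert [o.shape[0] for o in per_item] == [60, 100]
```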