diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
index ac5f79b56e49f..37216a5cfe574 100644
--- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py
+++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
     print("Chat completion output from base64 encoded audio:", result)
 
 
+def run_multi_audio(model: str) -> None:
+    from vllm.assets.audio import AudioAsset
+
+    # Two different audios to showcase batched inference.
+    audio_url = AudioAsset("winning_call").url
+    audio_base64 = encode_base64_content_from_url(audio_url)
+    audio_url2 = AudioAsset("azacinto_foscolo").url
+    audio_base64_2 = encode_base64_content_from_url(audio_url2)
+
+    # OpenAI-compatible schema (`input_audio`)
+    chat_completion_from_base64 = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Are these two audios the same?"},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64,
+                            "format": "wav",
+                        },
+                    },
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64_2,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        ],
+        model=model,
+        max_completion_tokens=64,
+    )
+
+    result = chat_completion_from_base64.choices[0].message.content
+    print("Chat completion output from input audio:", result)
+
+
 example_function_map = {
     "text-only": run_text_only,
     "single-image": run_single_image,
     "multi-image": run_multi_image,
+    "multi-audio": run_multi_audio,
     "video": run_video,
     "audio": run_audio,
 }
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index c25bbcd420c39..d831e9084db57 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 
 
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             mm_kwargs,
             tok_kwargs,
         )
+
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
@@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if input_features_mask is None:
             return None
 
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
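The processor change above keeps two views of the same audio features: a padded batch tensor (`input_features_padded`) so `audio_tower` can run a single batched forward, and per-item unpadded tensors (`input_features`) so each item's cached output stays independent of its batch-mates. Below is a minimal sketch of that pad/unpad split using made-up shapes and values; none of the names here are vLLM API, it only illustrates the tensor manipulation:

```python
import torch

# Hypothetical padded batch from the HF processor: 2 clips padded to
# seq_length=8 with num_features=4; clip 1 only has 5 valid frames.
features = torch.randn(2, 8, 4)
mask = torch.tensor([[True] * 8,
                     [True] * 5 + [False] * 3])

# Padded view: uniform shape, suitable for one batched audio-tower forward.
features_padded = features

# Unpadded view: one variable-length tensor per item, so the cached result
# for an item does not depend on how far its batch-mates were padded.
unpadded = [f[m] for f, m in zip(features, mask)]
print([tuple(t.shape) for t in unpadded])  # [(8, 4), (5, 4)]
```

To exercise the new example end to end, run the client script with `--chat-type multi-audio` against a running server (assuming the script's existing `--chat-type` flag, which selects entries from `example_function_map`).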