[Gemma3n] Fix audio batching (#24052)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit 0a74e9d0f2 (parent 8bd5844989)
@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
     print("Chat completion output from base64 encoded audio:", result)
 
 
+def run_multi_audio(model: str) -> None:
+    from vllm.assets.audio import AudioAsset
+
+    # Two different audios to showcase batched inference.
+    audio_url = AudioAsset("winning_call").url
+    audio_base64 = encode_base64_content_from_url(audio_url)
+    audio_url2 = AudioAsset("azacinto_foscolo").url
+    audio_base64_2 = encode_base64_content_from_url(audio_url2)
+
+    # OpenAI-compatible schema (`input_audio`)
+    chat_completion_from_base64 = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Are these two audios the same?"},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64,
+                            "format": "wav",
+                        },
+                    },
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64_2,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        ],
+        model=model,
+        max_completion_tokens=64,
+    )
+
+    result = chat_completion_from_base64.choices[0].message.content
+    print("Chat completion output from input audio:", result)
+
+
 example_function_map = {
     "text-only": run_text_only,
     "single-image": run_single_image,
     "multi-image": run_multi_image,
+    "multi-audio": run_multi_audio,
     "video": run_video,
     "audio": run_audio,
 }
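The example assumes the `client` handle and the `encode_base64_content_from_url` helper defined earlier in this script; a minimal standalone equivalent (the server URL and API key placeholder are assumptions) could look like:

import base64

import requests
from openai import OpenAI

# Point the OpenAI SDK at a locally running vLLM OpenAI-compatible server.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

def encode_base64_content_from_url(url: str) -> str:
    """Fetch content from a URL and encode it as a base64 string."""
    return base64.b64encode(requests.get(url).content).decode("utf-8")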
@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast
 
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 
 
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
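For intuition, here is a toy construction of the three fields above (shapes only; in vLLM the values come from the HF feature extractor, and the concrete lengths are assumptions for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two clips of different lengths, each (seq_length, num_features).
items = [torch.randn(80, 128), torch.randn(50, 128)]

# Padded batch for the audio tower: (num_audio, max_seq_length, num_features).
input_features_padded = pad_sequence(items, batch_first=True)
# Boolean validity mask: True where a frame is real, False where padded.
input_features_mask = pad_sequence(
    [torch.ones(f.shape[0], dtype=torch.bool) for f in items],
    batch_first=True)
# Per-item unpadded view, kept so cached items stay independent.
input_features = items

assert input_features_padded.shape == (2, 80, 128)
assert int(input_features_mask[1].sum()) == 50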
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             mm_kwargs,
             tok_kwargs,
         )
 
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
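A minimal sketch of the unpadding step (toy tensors, not the processor's actual data): boolean indexing with each item's mask drops the padded rows, so a cached item comes out identical whether it was processed alone or alongside other clips:

import torch

padded = torch.randn(2, 80, 128)
mask = torch.zeros(2, 80, dtype=torch.bool)
mask[0, :80] = True  # first clip fills the whole window
mask[1, :50] = True  # second clip is shorter

# Mirrors `f[mask] for f, mask in zip(...)` in the hunk above.
unpadded = [f[m] for f, m in zip(padded, mask)]
assert unpadded[0].shape == (80, 128)
assert unpadded[1].shape == (50, 128)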
@@ -206,9 +212,11 @@
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
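`MultiModalFieldConfig.batched(modality)` tells vLLM to slice that keyword along dim 0, one slice per item of the modality, which is what lets the processor cache store each audio clip separately; registering `input_features_padded` the same way keeps the new field in sync with the others. A toy illustration of that slicing (not vLLM's actual implementation):

import torch

def split_batched(field: torch.Tensor) -> list[torch.Tensor]:
    # One slice per multimodal item along the leading (batch) dim.
    return [field[i] for i in range(field.shape[0])]

feats = torch.randn(2, 100, 128)  # two audio items
per_item = split_batched(feats)
assert len(per_item) == 2 and per_item[0].shape == (100, 128)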
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if input_features_mask is None:
             return None
 
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
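The surrounding method follows the parse-and-validate pattern visible in the hunk: pop each expected kwarg and return None if any is missing, so the caller simply skips the audio tower. A condensed sketch of that contract, with hypothetical standalone names:

from typing import Optional

def parse_audio_kwargs(**kwargs: object) -> Optional[dict]:
    feats = kwargs.pop("input_features", None)
    mask = kwargs.pop("input_features_mask", None)
    padded = kwargs.pop("input_features_padded", None)
    # Any missing field means this request carries no usable audio.
    if feats is None or mask is None or padded is None:
        return None
    return {"input_features": feats,
            "input_features_mask": mask,
            "input_features_padded": padded}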
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
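Note the tower receives `~input_features_mask`: the stored mask marks valid frames, so inverting it yields a padding mask (True = padded). Why padding matters here, in toy form: one (batch, time, features) tensor goes through a single forward pass instead of a Python loop over clips, with the padding mask keeping padded frames out of the result. Shapes and the pooling step are assumptions for illustration, not the real audio tower:

import torch

B, T, F = 2, 80, 128
input_features = torch.randn(B, T, F)
padding_mask = torch.zeros(B, T, dtype=torch.bool)
padding_mask[1, 50:] = True  # second clip is padded past frame 50

# e.g. masked mean-pooling over valid frames, batched in one shot
valid = (~padding_mask).unsqueeze(-1).float()
pooled = (input_features * valid).sum(dim=1) / valid.sum(dim=1)
assert pooled.shape == (B, F)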