Add AudioFlamingo3 model support (#30539)

Signed-off-by: Lasha <26011196+lashahub@users.noreply.github.com>
Signed-off-by: Lasha Koroshinadze <26011196+lashahub@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Lasha Koroshinadze 2025-12-14 05:14:55 -05:00 committed by GitHub
parent 1a55cfafcb
commit 3a20450d31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 989 additions and 44 deletions

View File

@ -659,6 +659,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |

View File

@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
RawAudio,
TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
model_name = "mistralai/Voxtral-Mini-3B-2507"
tokenizer = MistralTokenizer.from_hf_hub(model_name)
# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/audio-flamingo-3-hf"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
config_format="mistral",
load_format="mistral",
tokenizer_mode="mistral",
enforce_eager=True,
enable_chunked_prefill=False,
)
text_chunk = TextChunk(text=question)
audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
for i in range(audio_count)
]
audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
]
# AudioFlamingo3 uses <sound> token for audio
audio_placeholder = "<sound>" * audio_count
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
req = ChatCompletionRequest(messages=messages, model=model_name)
tokens = tokenizer.encode_chat_completion(req)
prompt_ids, audios = tokens.tokens, tokens.audios
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
multi_modal_data = {"audio": audios_and_sr}
prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_placeholder}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt_token_ids=prompt_ids,
multi_modal_data=multi_modal_data,
prompt=prompt,
)
@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
)
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
RawAudio,
TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
model_name = "mistralai/Voxtral-Mini-3B-2507"
tokenizer = MistralTokenizer.from_hf_hub(model_name)
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
config_format="mistral",
load_format="mistral",
tokenizer_mode="mistral",
enforce_eager=True,
enable_chunked_prefill=False,
)
text_chunk = TextChunk(text=question)
audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
for i in range(audio_count)
]
audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
]
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
req = ChatCompletionRequest(messages=messages, model=model_name)
tokens = tokenizer.encode_chat_completion(req)
prompt_ids, audios = tokens.tokens, tokens.audios
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
multi_modal_data = {"audio": audios_and_sr}
return ModelRequestData(
engine_args=engine_args,
prompt_token_ids=prompt_ids,
multi_modal_data=multi_modal_data,
)
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "Whisper only support single audio input per prompt"
@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"voxtral": run_voxtral,
"audioflamingo3": run_audioflamingo3,
"gemma3n": run_gemma3n,
"granite_speech": run_granite_speech,
"midashenglm": run_midashenglm,
@ -392,6 +420,7 @@ model_example_map = {
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox,
"voxtral": run_voxtral,
"whisper": run_whisper,
}

View File

@ -0,0 +1 @@
{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other.", "(B) To indicate that language cannot express clearly, satirizing the inversion of black and white in the world"], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645], [5349, 8, 2014, 13216, 429, 4128, 4157, 3158, 9355, 11, 7578, 404, 4849, 279, 46488, 315, 3691, 323, 4158, 304, 279, 1879, 151645, 151671]]}

View File

@ -0,0 +1 @@
{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]}

View File

@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from vllm import LLM, SamplingParams
MODEL_NAME = "nvidia/audio-flamingo-3-hf"
def get_fixture_path(filename):
return os.path.join(
os.path.dirname(__file__), "../../fixtures/audioflamingo3", filename
)
@pytest.fixture(scope="module")
def llm():
# Check if the model is supported by the current transformers version
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
model_info.check_transformers_version(on_fail="skip")
try:
llm = LLM(
model=MODEL_NAME,
trust_remote_code=True,
dtype="bfloat16",
enforce_eager=True,
limit_mm_per_prompt={"audio": 1},
)
return llm
except Exception as e:
pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
def test_single_generation(llm):
fixture_path = get_fixture_path("expected_results_single.json")
if not os.path.exists(fixture_path):
pytest.skip(f"Fixture not found: {fixture_path}")
with open(fixture_path) as f:
expected = json.load(f)
audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav"
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": audio_url}},
{"type": "text", "text": "Transcribe the input speech."},
],
}
]
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.chat(
messages=messages,
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text.strip()
expected_text = expected["transcriptions"][0]
assert expected_text in generated_text or generated_text in expected_text
def test_batched_generation(llm):
fixture_path = get_fixture_path("expected_results_batched.json")
if not os.path.exists(fixture_path):
pytest.skip(f"Fixture not found: {fixture_path}")
with open(fixture_path) as f:
expected = json.load(f)
items = [
{
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
"question": "What is surprising about the relationship "
"between the barking and the music?",
"expected_idx": 0,
},
{
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
"question": (
"Why is the philosopher's name mentioned in the lyrics? "
"(A) To express a sense of nostalgia "
"(B) To indicate that language cannot express clearly, "
"satirizing the inversion of black and white in the world "
"(C) To add depth and complexity to the lyrics "
"(D) To showcase the wisdom and influence of the philosopher"
),
"expected_idx": 1,
},
]
conversations = []
for item in items:
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": item["audio_url"]}},
{"type": "text", "text": item["question"]},
],
}
]
conversations.append(messages)
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.chat(
messages=conversations,
sampling_params=sampling_params,
)
for i, output in enumerate(outputs):
generated_text = output.outputs[0].text.strip()
expected_text = expected["transcriptions"][i]
assert expected_text in generated_text or generated_text in expected_text

View File

@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from transformers import PretrainedConfig
from tests.models.registry import HF_EXAMPLE_MODELS
class MockAudioFlamingo3Config(PretrainedConfig):
model_type = "audioflamingo3"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.audio_config = PretrainedConfig()
self.text_config = PretrainedConfig()
class MockAudioFlamingo3Processor:
def __init__(self):
self.audio_token = "<sound>"
self.audio_token_id = 12345
self.feature_extractor = MockFeatureExtractor()
def __call__(self, text=None, audios=None, **kwargs):
return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]}
class MockFeatureExtractor:
def __init__(self):
self.sampling_rate = 16000
self.chunk_length = 30
@pytest.fixture
def mock_ctx():
config = MockAudioFlamingo3Config()
ctx = MagicMock()
ctx.get_hf_config.return_value = config
ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor()
ctx.model_config.hf_config = config
return ctx
@pytest.fixture(autouse=True)
def check_transformers_version():
# Check if the model is supported by the current transformers version
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
model_info.check_transformers_version(on_fail="skip")
def test_audio_chunk_counting(mock_ctx):
from vllm.model_executor.models.audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3MultiModalProcessor,
AudioFlamingo3ProcessingInfo,
)
info = AudioFlamingo3ProcessingInfo(mock_ctx)
processor = AudioFlamingo3MultiModalProcessor(
info, AudioFlamingo3DummyInputsBuilder(info)
)
sr = 16000
audio_1 = np.zeros(30 * sr)
audio_2 = np.zeros(45 * sr)
mm_data = {"audio": [audio_1, audio_2]}
prompt = "<|user|>Listen.<|end|>"
from vllm.multimodal.processing import BaseMultiModalProcessor
def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)}
with pytest.MonkeyPatch.context() as mp:
mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
processed = processor._call_hf_processor(prompt, mm_data, {}, {})
chunk_counts = processed["chunk_counts"]
assert chunk_counts[0].item() == 1
assert chunk_counts[1].item() == 2
assert len(chunk_counts) == 2
def test_dummy_data_generation(mock_ctx):
from vllm.model_executor.models.audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3ProcessingInfo,
)
info = AudioFlamingo3ProcessingInfo(mock_ctx)
builder = AudioFlamingo3DummyInputsBuilder(info)
mm_counts = {"audio": 2}
dummy_data = builder.get_dummy_mm_data(100, mm_counts, None)
assert "audio" in dummy_data
assert len(dummy_data["audio"]) == 2
expected_len = 600 * 16000
assert len(dummy_data["audio"][0]) == expected_len

View File

@ -578,6 +578,9 @@ _AUTOMATIC_CONVERTED_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BeeForConditionalGeneration": _HfExamplesInfo(
"Open-Bee/Bee-8B-RL",

View File

@ -0,0 +1,639 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.audioflamingo3 import (
AudioFlamingo3Config,
AudioFlamingo3Processor,
)
from transformers.models.qwen2_audio import Qwen2AudioEncoder
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
init_vllm_registered_model,
maybe_prefix,
)
MAX_AUDIO_LEN = 10 * 60
# === Audio Inputs === #
class AudioFlamingo3FeatureInputs(TensorSchema):
"""
Dimensions:
- num_chunks: Number of audio chunks (flattened)
- nmb: Number of mel bins
- num_audios: Number of original audio files
"""
type: Literal["audio_features"]
input_features: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("num_chunks", "nmb", 3000),
]
feature_attention_mask: Annotated[
torch.Tensor,
TensorShape("num_chunks", 3000),
]
chunk_counts: Annotated[
torch.Tensor,
TensorShape("num_audios"),
]
class AudioFlamingo3EmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size
- naf: Number of audio features
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["audio_embeds"] = "audio_embeds"
audio_embeds: Annotated[
list[torch.Tensor],
TensorShape("bn", "naf", "hs"),
]
AudioFlamingo3Inputs: TypeAlias = (
AudioFlamingo3FeatureInputs | AudioFlamingo3EmbeddingInputs
)
class AudioFlamingo3Encoder(Qwen2AudioEncoder):
def __init__(
self,
config: PretrainedConfig,
):
super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
def forward(
self,
input_features: torch.Tensor | list[torch.Tensor],
attention_mask: torch.Tensor = None,
):
# input_features: (batch, num_mel_bins, seq_len)
if isinstance(input_features, list):
input_features = torch.stack(input_features)
hidden_states = nn.functional.gelu(self.conv1(input_features))
hidden_states = nn.functional.gelu(self.conv2(hidden_states))
hidden_states = hidden_states.transpose(-1, -2)
hidden_states = (
hidden_states + self.embed_positions.weight[: hidden_states.size(-2), :]
).to(hidden_states.dtype)
for layer in self.layers:
layer_outputs = layer(hidden_states, attention_mask)
hidden_states = layer_outputs[0]
# AvgPool (time/2) + LayerNorm
# hidden_states: (batch, seq_len, hidden_size)
hidden_states = hidden_states.permute(0, 2, 1) # (batch, hidden_size, seq_len)
hidden_states = self.avg_pooler(hidden_states)
hidden_states = hidden_states.permute(
0, 2, 1
) # (batch, seq_len/2, hidden_size)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor):
"""
Computes the output length of the convolutional layers and the output length
of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
class AudioFlamingo3MultiModalProjector(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.linear_1 = nn.Linear(
config.audio_config.hidden_size,
config.text_config.hidden_size,
bias=config.projector_bias,
)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(
config.text_config.hidden_size,
config.text_config.hidden_size,
bias=config.projector_bias,
)
def forward(self, audio_features):
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(AudioFlamingo3Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
def get_feature_extractor(self, **kwargs: object):
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor
return feature_extractor
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
class AudioFlamingo3DummyInputsBuilder(
BaseDummyInputsBuilder[AudioFlamingo3ProcessingInfo]
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token
return audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = MAX_AUDIO_LEN * sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
chunk_counts = hf_inputs.get("chunk_counts")
if chunk_counts is not None:
return dict(
audio_embeds=MultiModalFieldConfig.batched("audio"),
input_features=MultiModalFieldConfig.flat_from_sizes(
"audio", chunk_counts, dim=0
),
feature_attention_mask=MultiModalFieldConfig.flat_from_sizes(
"audio", chunk_counts, dim=0
),
chunk_counts=MultiModalFieldConfig.batched("audio"),
)
return dict(
audio_embeds=MultiModalFieldConfig.batched("audio"),
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
chunk_counts=MultiModalFieldConfig.batched("audio"),
)
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_audioflamingo3_field_config,
)
return super()._parse_audio_data(data)
class AudioFlamingo3MultiModalProcessor(
BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate
)
def _call_hf_processor(
self,
prompt: str,
mm_data: dict[str, object],
mm_kwargs: Mapping[str, Any],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
audios = mm_data.pop("audios", [])
if audios:
mm_data["audio"] = audios
if not mm_data.get("audio", []):
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
# Calculate chunk counts
audio_list = mm_data.get("audio")
if not isinstance(audio_list, list):
audio_list = [audio_list]
chunk_counts = []
sampling_rate = feature_extractor.sampling_rate
chunk_length = feature_extractor.chunk_length
window_size = int(sampling_rate * chunk_length)
# MAX_AUDIO_LEN is 10 * 60 in HF processor.
max_windows = int(MAX_AUDIO_LEN // chunk_length)
for audio in audio_list:
# audio is numpy array or list
n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
n_win = max(1, (n_samples + window_size - 1) // window_size)
if n_win > max_windows:
n_win = max_windows
chunk_counts.append(n_win)
outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "input_features_mask" in outputs:
outputs["feature_attention_mask"] = outputs.pop("input_features_mask")
outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)
return outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _audioflamingo3_field_config(hf_inputs)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
audio_token = getattr(processor, "audio_token", "<sound>")
audio_token_id = vocab.get(audio_token)
if audio_token_id is None:
# Fallback if not found, though it should be there
audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data()
feature_attention_mask = out_mm_data.get("feature_attention_mask")
chunk_counts = out_mm_data.get("chunk_counts")
def get_replacement_audioflamingo3(item_idx: int):
if feature_attention_mask is not None:
if chunk_counts is not None:
counts = (
chunk_counts.tolist()
if isinstance(chunk_counts, torch.Tensor)
else chunk_counts
)
start_idx = sum(counts[:item_idx])
count = counts[item_idx]
end_idx = start_idx + count
if isinstance(feature_attention_mask, list):
mask_list = feature_attention_mask[start_idx:end_idx]
if len(mask_list) > 0 and isinstance(
mask_list[0], torch.Tensor
):
mask = torch.stack(mask_list)
else:
mask = torch.tensor(mask_list)
else:
mask = feature_attention_mask[start_idx:end_idx]
else:
# feature_attention_mask is list[Tensor] or Tensor
if isinstance(feature_attention_mask, list):
mask = feature_attention_mask[item_idx]
else:
mask = feature_attention_mask[item_idx].unsqueeze(0)
# mask shape: (num_chunks, 3000)
input_lengths = mask.sum(-1)
conv_lengths = (input_lengths - 1) // 2 + 1
audio_output_lengths = (conv_lengths - 2) // 2 + 1
num_features = audio_output_lengths.sum().item()
else:
audio_embeds = out_mm_data["audio_embeds"][item_idx]
num_features = audio_embeds.shape[0]
if num_features == 0:
raise ValueError("Audio is too short")
audio_tokens = [audio_token_id] * int(num_features)
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_audioflamingo3,
)
]
@MULTIMODAL_REGISTRY.register_processor(
AudioFlamingo3MultiModalProcessor,
info=AudioFlamingo3ProcessingInfo,
dummy_inputs=AudioFlamingo3DummyInputsBuilder,
)
class AudioFlamingo3ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
"""
AudioFlamingo3 model for conditional generation.
This model integrates a Whisper-based audio encoder with a Qwen2 language model.
It supports multi-chunk audio processing.
"""
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model.",
connector="multi_modal_projector.",
tower_model="audio_tower.",
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.audio_tower = AudioFlamingo3Encoder(
config.audio_config,
)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> AudioFlamingo3Inputs | None:
input_features = kwargs.pop("input_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
chunk_counts = kwargs.pop("chunk_counts", None)
if input_features is None and audio_embeds is None:
return None
if audio_embeds is not None:
return AudioFlamingo3EmbeddingInputs(
type="audio_embeds", audio_embeds=audio_embeds
)
if input_features is not None:
return AudioFlamingo3FeatureInputs(
type="audio_features",
input_features=input_features,
feature_attention_mask=feature_attention_mask,
chunk_counts=chunk_counts,
)
raise AssertionError("This line should be unreachable.")
def _process_audio_input(
self, audio_input: AudioFlamingo3Inputs
) -> torch.Tensor | tuple[torch.Tensor, ...]:
if audio_input["type"] == "audio_embeds":
audio_embeds = audio_input["audio_embeds"]
return tuple(audio_embeds)
input_features = audio_input["input_features"]
feature_attention_mask = audio_input["feature_attention_mask"]
chunk_counts = audio_input.get("chunk_counts")
if isinstance(input_features, list):
input_features = torch.cat(input_features, dim=0)
feature_attention_mask = torch.cat(feature_attention_mask, dim=0)
if chunk_counts is None:
chunk_counts = [1] * input_features.shape[0]
elif isinstance(chunk_counts, torch.Tensor):
chunk_counts = chunk_counts.tolist()
elif (
isinstance(chunk_counts, list)
and chunk_counts
and isinstance(chunk_counts[0], torch.Tensor)
):
chunk_counts = [c.item() for c in chunk_counts]
# Calculate output lengths
input_lengths = feature_attention_mask.sum(-1)
# Conv downsampling
conv_lengths = (input_lengths - 1) // 2 + 1
# AvgPool downsampling
audio_output_lengths = (conv_lengths - 2) // 2 + 1
batch_size, _, max_mel_seq_len = input_features.shape
# Calculate max_seq_len after convs (before pooling) for attention mask
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=conv_lengths.dtype,
device=conv_lengths.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
batch_size, 1, max_seq_len, max_seq_len
)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.audio_tower.conv1.weight.dtype,
device=self.audio_tower.conv1.weight.device,
)
audio_attention_mask[audio_attention_mask_] = float("-inf")
# Forward pass
audio_features = self.audio_tower(
input_features, attention_mask=audio_attention_mask
)
# Project
audio_features = self.multi_modal_projector(audio_features)
# Masking after pooling
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_output_lengths = audio_output_lengths.unsqueeze(1)
audio_features_mask = (
torch.arange(max_audio_tokens)
.expand(num_audios, max_audio_tokens)
.to(audio_output_lengths.device)
< audio_output_lengths
)
masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)
# Split to tuple of embeddings for individual audio input.
chunk_embeddings = torch.split(
masked_audio_features, audio_output_lengths.flatten().tolist()
)
grouped_embeddings = []
current_idx = 0
for count in chunk_counts:
audio_chunks = chunk_embeddings[current_idx : current_idx + count]
grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
current_idx += count
return tuple(grouped_embeddings)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return []
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@ -264,6 +264,10 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
"AudioFlamingo3ForConditionalGeneration": (
"audioflamingo3",
"AudioFlamingo3ForConditionalGeneration",
),
"AyaVisionForConditionalGeneration": (
"aya_vision",
"AyaVisionForConditionalGeneration",