Add AudioFlamingo3 model support (#30539)
Signed-off-by: Lasha <26011196+lashahub@users.noreply.github.com>
Signed-off-by: Lasha Koroshinadze <26011196+lashahub@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Parent 1a55cfafcb, commit 3a20450d31
@@ -659,6 +659,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
 |--------------|--------|--------|-------------------|----------------------|---------------------------|
 | `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
+| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
 | `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
 | `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
 | `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
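Beyond the offline examples below, a checkpoint from this table row should also be reachable through the OpenAI-compatible server. The snippet that follows is an illustrative sketch rather than part of the commit: it assumes a server started with `vllm serve nvidia/audio-flamingo-3-hf` and reuses one of the audio URLs that appears later in the new tests.

```python
# Minimal sketch (not part of this commit): query a running
# `vllm serve nvidia/audio-flamingo-3-hf` instance with one audio clip.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Clip taken from the new test fixtures; any reachable WAV/MP3 URL should work.
audio_url = (
    "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
    "assets/Why_do_we_ask_questions_converted.wav"
)

response = client.chat.completions.create(
    model="nvidia/audio-flamingo-3-hf",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "audio_url", "audio_url": {"url": audio_url}},
                {"type": "text", "text": "Transcribe the input speech."},
            ],
        }
    ],
    max_tokens=128,
    temperature=0.0,
)
print(response.choices[0].message.content)
```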
@@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.
 
 
-# Voxtral
-# Make sure to install mistral-common[audio].
-def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
-    from mistral_common.audio import Audio
-    from mistral_common.protocol.instruct.chunk import (
-        AudioChunk,
-        RawAudio,
-        TextChunk,
-    )
-    from mistral_common.protocol.instruct.messages import (
-        UserMessage,
-    )
-    from mistral_common.protocol.instruct.request import ChatCompletionRequest
-    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-
-    model_name = "mistralai/Voxtral-Mini-3B-2507"
-    tokenizer = MistralTokenizer.from_hf_hub(model_name)
-
+# AudioFlamingo3
+def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
+    model_name = "nvidia/audio-flamingo-3-hf"
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=8192,
+        max_model_len=4096,
         max_num_seqs=2,
         limit_mm_per_prompt={"audio": audio_count},
-        config_format="mistral",
-        load_format="mistral",
-        tokenizer_mode="mistral",
         enforce_eager=True,
         enable_chunked_prefill=False,
     )
 
-    text_chunk = TextChunk(text=question)
-    audios = [
-        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
-        for i in range(audio_count)
-    ]
-    audio_chunks = [
-        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
-    ]
+    # AudioFlamingo3 uses <sound> token for audio
+    audio_placeholder = "<sound>" * audio_count
 
-    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
-
-    req = ChatCompletionRequest(messages=messages, model=model_name)
-
-    tokens = tokenizer.encode_chat_completion(req)
-    prompt_ids, audios = tokens.tokens, tokens.audios
-
-    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
-
-    multi_modal_data = {"audio": audios_and_sr}
+    prompt = (
+        "<|im_start|>system\n"
+        "You are a helpful assistant.<|im_end|>\n"
+        "<|im_start|>user\n"
+        f"{audio_placeholder}{question}<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
 
     return ModelRequestData(
         engine_args=engine_args,
-        prompt_token_ids=prompt_ids,
-        multi_modal_data=multi_modal_data,
+        prompt=prompt,
     )
 
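To show how a request built this way is actually consumed, here is a minimal offline sketch that pairs the <sound> placeholder and chat template above with LLM.generate. It is not part of the commit; the AudioAsset helper and the exact engine settings are assumptions made for the sake of a runnable example.

```python
# Minimal end-to-end sketch (not part of this commit) showing how a request
# like the one built by run_audioflamingo3 would be consumed.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset  # assumed helper for a sample clip

audio, sr = AudioAsset("winning_call").audio_and_sample_rate

prompt = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<sound>Transcribe the input speech.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

llm = LLM(
    model="nvidia/audio-flamingo-3-hf",
    max_model_len=4096,
    limit_mm_per_prompt={"audio": 1},
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": [(audio, sr)]}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```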
@@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
     )
 
 
+# Voxtral
+# Make sure to install mistral-common[audio].
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+    from mistral_common.audio import Audio
+    from mistral_common.protocol.instruct.chunk import (
+        AudioChunk,
+        RawAudio,
+        TextChunk,
+    )
+    from mistral_common.protocol.instruct.messages import (
+        UserMessage,
+    )
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    model_name = "mistralai/Voxtral-Mini-3B-2507"
+    tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        config_format="mistral",
+        load_format="mistral",
+        tokenizer_mode="mistral",
+        enforce_eager=True,
+        enable_chunked_prefill=False,
+    )
+
+    text_chunk = TextChunk(text=question)
+    audios = [
+        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+        for i in range(audio_count)
+    ]
+    audio_chunks = [
+        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+    ]
+
+    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+    req = ChatCompletionRequest(messages=messages, model=model_name)
+
+    tokens = tokenizer.encode_chat_completion(req)
+    prompt_ids, audios = tokens.tokens, tokens.audios
+
+    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+    multi_modal_data = {"audio": audios_and_sr}
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt_token_ids=prompt_ids,
+        multi_modal_data=multi_modal_data,
+    )
+
+
 # Whisper
 def run_whisper(question: str, audio_count: int) -> ModelRequestData:
     assert audio_count == 1, "Whisper only support single audio input per prompt"

@@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
 
 
 model_example_map = {
-    "voxtral": run_voxtral,
+    "audioflamingo3": run_audioflamingo3,
     "gemma3n": run_gemma3n,
     "granite_speech": run_granite_speech,
     "midashenglm": run_midashenglm,

@@ -392,6 +420,7 @@ model_example_map = {
     "qwen2_audio": run_qwen2_audio,
     "qwen2_5_omni": run_qwen2_5_omni,
     "ultravox": run_ultravox,
+    "voxtral": run_voxtral,
     "whisper": run_whisper,
 }
 
@@ -0,0 +1 @@
{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other.", "(B) To indicate that language cannot express clearly, satirizing the inversion of black and white in the world"], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645], [5349, 8, 2014, 13216, 429, 4128, 4157, 3158, 9355, 11, 7578, 404, 4849, 279, 46488, 315, 3691, 323, 4158, 304, 279, 1879, 151645, 151671]]}

@@ -0,0 +1 @@
{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]}
tests/models/multimodal/generation/test_audioflamingo3.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from vllm import LLM, SamplingParams

MODEL_NAME = "nvidia/audio-flamingo-3-hf"


def get_fixture_path(filename):
    return os.path.join(
        os.path.dirname(__file__), "../../fixtures/audioflamingo3", filename
    )


@pytest.fixture(scope="module")
def llm():
    # Check if the model is supported by the current transformers version
    model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
    model_info.check_transformers_version(on_fail="skip")

    try:
        llm = LLM(
            model=MODEL_NAME,
            trust_remote_code=True,
            dtype="bfloat16",
            enforce_eager=True,
            limit_mm_per_prompt={"audio": 1},
        )
        return llm
    except Exception as e:
        pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")


def test_single_generation(llm):
    fixture_path = get_fixture_path("expected_results_single.json")
    if not os.path.exists(fixture_path):
        pytest.skip(f"Fixture not found: {fixture_path}")

    with open(fixture_path) as f:
        expected = json.load(f)

    audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio_url", "audio_url": {"url": audio_url}},
                {"type": "text", "text": "Transcribe the input speech."},
            ],
        }
    ]

    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

    outputs = llm.chat(
        messages=messages,
        sampling_params=sampling_params,
    )
    generated_text = outputs[0].outputs[0].text.strip()

    expected_text = expected["transcriptions"][0]

    assert expected_text in generated_text or generated_text in expected_text


def test_batched_generation(llm):
    fixture_path = get_fixture_path("expected_results_batched.json")
    if not os.path.exists(fixture_path):
        pytest.skip(f"Fixture not found: {fixture_path}")

    with open(fixture_path) as f:
        expected = json.load(f)

    items = [
        {
            "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
            "question": "What is surprising about the relationship "
            "between the barking and the music?",
            "expected_idx": 0,
        },
        {
            "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
            "question": (
                "Why is the philosopher's name mentioned in the lyrics? "
                "(A) To express a sense of nostalgia "
                "(B) To indicate that language cannot express clearly, "
                "satirizing the inversion of black and white in the world "
                "(C) To add depth and complexity to the lyrics "
                "(D) To showcase the wisdom and influence of the philosopher"
            ),
            "expected_idx": 1,
        },
    ]

    conversations = []
    for item in items:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "audio_url", "audio_url": {"url": item["audio_url"]}},
                    {"type": "text", "text": item["question"]},
                ],
            }
        ]
        conversations.append(messages)

    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

    outputs = llm.chat(
        messages=conversations,
        sampling_params=sampling_params,
    )

    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text.strip()
        expected_text = expected["transcriptions"][i]

        assert expected_text in generated_text or generated_text in expected_text
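The expected-results fixtures read by these tests are plain JSON with a transcriptions list and the matching token_ids. If they ever need to be refreshed, a hypothetical helper along the lines below would reproduce the single-clip entry with the same greedy settings; it is only a sketch and is not part of the test suite.

```python
import json

from vllm import LLM, SamplingParams


def dump_single_fixture(path: str) -> None:
    # Re-run the same greedy request the test issues and store its output.
    llm = LLM(model="nvidia/audio-flamingo-3-hf", enforce_eager=True,
              limit_mm_per_prompt={"audio": 1})
    audio_url = (
        "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
        "assets/Why_do_we_ask_questions_converted.wav"
    )
    messages = [{
        "role": "user",
        "content": [
            {"type": "audio_url", "audio_url": {"url": audio_url}},
            {"type": "text", "text": "Transcribe the input speech."},
        ],
    }]
    out = llm.chat(messages=messages,
                   sampling_params=SamplingParams(temperature=0.0, max_tokens=128))
    completion = out[0].outputs[0]
    with open(path, "w") as f:
        json.dump(
            {"transcriptions": [completion.text.strip()],
             "token_ids": [list(completion.token_ids)]},
            f,
        )
```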
tests/models/multimodal/processing/test_audioflamingo3.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
from transformers import PretrainedConfig

from tests.models.registry import HF_EXAMPLE_MODELS


class MockAudioFlamingo3Config(PretrainedConfig):
    model_type = "audioflamingo3"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.audio_config = PretrainedConfig()
        self.text_config = PretrainedConfig()


class MockAudioFlamingo3Processor:
    def __init__(self):
        self.audio_token = "<sound>"
        self.audio_token_id = 12345
        self.feature_extractor = MockFeatureExtractor()

    def __call__(self, text=None, audios=None, **kwargs):
        return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]}


class MockFeatureExtractor:
    def __init__(self):
        self.sampling_rate = 16000
        self.chunk_length = 30


@pytest.fixture
def mock_ctx():
    config = MockAudioFlamingo3Config()

    ctx = MagicMock()
    ctx.get_hf_config.return_value = config
    ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor()
    ctx.model_config.hf_config = config
    return ctx


@pytest.fixture(autouse=True)
def check_transformers_version():
    # Check if the model is supported by the current transformers version
    model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
    model_info.check_transformers_version(on_fail="skip")


def test_audio_chunk_counting(mock_ctx):
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3MultiModalProcessor,
        AudioFlamingo3ProcessingInfo,
    )

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    processor = AudioFlamingo3MultiModalProcessor(
        info, AudioFlamingo3DummyInputsBuilder(info)
    )

    sr = 16000
    audio_1 = np.zeros(30 * sr)
    audio_2 = np.zeros(45 * sr)

    mm_data = {"audio": [audio_1, audio_2]}
    prompt = "<|user|>Listen.<|end|>"

    from vllm.multimodal.processing import BaseMultiModalProcessor

    def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
        return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)}

    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)

        processed = processor._call_hf_processor(prompt, mm_data, {}, {})

        chunk_counts = processed["chunk_counts"]

        assert chunk_counts[0].item() == 1
        assert chunk_counts[1].item() == 2
        assert len(chunk_counts) == 2


def test_dummy_data_generation(mock_ctx):
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3ProcessingInfo,
    )

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    builder = AudioFlamingo3DummyInputsBuilder(info)

    mm_counts = {"audio": 2}
    dummy_data = builder.get_dummy_mm_data(100, mm_counts, None)

    assert "audio" in dummy_data
    assert len(dummy_data["audio"]) == 2

    expected_len = 600 * 16000
    assert len(dummy_data["audio"][0]) == expected_len
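The chunk-count assertions follow directly from the windowing arithmetic used by the processor: audio is cut into 30 s windows at 16 kHz and capped at 10 minutes, so a 30 s clip yields one chunk and a 45 s clip yields two. A standalone sketch of that calculation, mirroring the rounding in the code shown below rather than importing it:

```python
import math

SAMPLING_RATE = 16_000       # Hz, from the mock feature extractor
CHUNK_LENGTH_S = 30          # seconds per window
MAX_AUDIO_LEN_S = 10 * 60    # processor cap, 10 minutes


def num_chunks(duration_s: float) -> int:
    """Ceiling-divide the clip into 30 s windows, capped at 10 minutes of audio."""
    n_samples = int(duration_s * SAMPLING_RATE)
    window = SAMPLING_RATE * CHUNK_LENGTH_S
    n_win = max(1, math.ceil(n_samples / window))
    return min(n_win, MAX_AUDIO_LEN_S // CHUNK_LENGTH_S)


assert num_chunks(30) == 1     # matches chunk_counts[0] in the test
assert num_chunks(45) == 2     # matches chunk_counts[1]
assert num_chunks(3600) == 20  # anything over 10 minutes is capped at 20 chunks
```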
@@ -578,6 +578,9 @@ _AUTOMATIC_CONVERTED_MODELS = {
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
+    "AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
+        "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
+    ),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
     "BeeForConditionalGeneration": _HfExamplesInfo(
         "Open-Bee/Bee-8B-RL",
vllm/model_executor/models/audioflamingo3.py (new file, 639 lines)
@@ -0,0 +1,639 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias

import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.audioflamingo3 import (
    AudioFlamingo3Config,
    AudioFlamingo3Processor,
)
from transformers.models.qwen2_audio import Qwen2AudioEncoder

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    init_vllm_registered_model,
    maybe_prefix,
)

MAX_AUDIO_LEN = 10 * 60


# === Audio Inputs === #
class AudioFlamingo3FeatureInputs(TensorSchema):
    """
    Dimensions:
        - num_chunks: Number of audio chunks (flattened)
        - nmb: Number of mel bins
        - num_audios: Number of original audio files
    """

    type: Literal["audio_features"]
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "nmb", 3000),
    ]

    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("num_chunks", 3000),
    ]

    chunk_counts: Annotated[
        torch.Tensor,
        TensorShape("num_audios"),
    ]


class AudioFlamingo3EmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
            backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs"),
    ]


AudioFlamingo3Inputs: TypeAlias = (
    AudioFlamingo3FeatureInputs | AudioFlamingo3EmbeddingInputs
)
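To make the two input schemas concrete, the toy dicts below spell out shapes that satisfy them. The mel-bin count and hidden size are placeholders chosen for illustration, not values read from the AudioFlamingo3 config, and nothing in this file builds these dicts directly.

```python
import torch

# Feature-style input: two audios flattened into three chunks in total.
n_mel = 128          # placeholder "nmb"; the real value comes from the audio config
features_input = {
    "type": "audio_features",
    "input_features": torch.zeros(3, n_mel, 3000),
    "feature_attention_mask": torch.ones(3, 3000, dtype=torch.long),
    "chunk_counts": torch.tensor([1, 2]),  # audio 0 -> 1 chunk, audio 1 -> 2 chunks
}

# Embedding-style input: one precomputed tensor of shape (naf, hs) per audio.
hidden_size = 3584   # placeholder "hs"; must match the language backbone
embeds_input = {
    "type": "audio_embeds",
    "audio_embeds": [torch.zeros(750, hidden_size), torch.zeros(1500, hidden_size)],
}
```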
class AudioFlamingo3Encoder(Qwen2AudioEncoder):
    def __init__(
        self,
        config: PretrainedConfig,
    ):
        super().__init__(config)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
        # self.layer_norm is already initialized in super().__init__

    def forward(
        self,
        input_features: torch.Tensor | list[torch.Tensor],
        attention_mask: torch.Tensor = None,
    ):
        # input_features: (batch, num_mel_bins, seq_len)
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        hidden_states = hidden_states.transpose(-1, -2)
        hidden_states = (
            hidden_states + self.embed_positions.weight[: hidden_states.size(-2), :]
        ).to(hidden_states.dtype)

        for layer in self.layers:
            layer_outputs = layer(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]

        # AvgPool (time/2) + LayerNorm
        # hidden_states: (batch, seq_len, hidden_size)
        hidden_states = hidden_states.permute(0, 2, 1)  # (batch, hidden_size, seq_len)
        hidden_states = self.avg_pooler(hidden_states)
        hidden_states = hidden_states.permute(
            0, 2, 1
        )  # (batch, seq_len/2, hidden_size)
        hidden_states = self.layer_norm(hidden_states)

        return hidden_states

    def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor):
        """
        Computes the output length of the convolutional layers and the output length
        of the audio encoder
        """
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths
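The length bookkeeping is easy to check by hand: a full 30 s chunk arrives as 3000 mel frames, the stride-2 convolution stage reduces that to 1500 positions, and the stride-2 average pool halves it again to 750 frames per chunk. A standalone sketch of the same arithmetic, mirroring `_get_feat_extract_output_lengths` without importing it:

```python
def conv_and_pool_lengths(mel_frames: int) -> tuple[int, int]:
    # Same formulas as AudioFlamingo3Encoder._get_feat_extract_output_lengths:
    # the conv stack halves the sequence, then AvgPool1d(kernel=2, stride=2)
    # halves it again.
    conv_len = (mel_frames - 1) // 2 + 1
    out_len = (conv_len - 2) // 2 + 1
    return conv_len, out_len


assert conv_and_pool_lengths(3000) == (1500, 750)  # one full 30 s chunk
assert conv_and_pool_lengths(1000) == (500, 250)   # a shorter, padded chunk
```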
class AudioFlamingo3MultiModalProjector(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.audio_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.projector_bias,
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.projector_bias,
        )

    def forward(self, audio_features):
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config(AudioFlamingo3Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)

    def get_feature_extractor(self, **kwargs: object):
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": None}


class AudioFlamingo3DummyInputsBuilder(
    BaseDummyInputsBuilder[AudioFlamingo3ProcessingInfo]
):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        audio_token = hf_processor.audio_token
        return audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = MAX_AUDIO_LEN * sampling_rate
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }


def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    chunk_counts = hf_inputs.get("chunk_counts")
    if chunk_counts is not None:
        return dict(
            audio_embeds=MultiModalFieldConfig.batched("audio"),
            input_features=MultiModalFieldConfig.flat_from_sizes(
                "audio", chunk_counts, dim=0
            ),
            feature_attention_mask=MultiModalFieldConfig.flat_from_sizes(
                "audio", chunk_counts, dim=0
            ),
            chunk_counts=MultiModalFieldConfig.batched("audio"),
        )
    return dict(
        audio_embeds=MultiModalFieldConfig.batched("audio"),
        input_features=MultiModalFieldConfig.batched("audio"),
        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
        chunk_counts=MultiModalFieldConfig.batched("audio"),
    )


class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_audioflamingo3_field_config,
            )
        return super()._parse_audio_data(data)
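The flat_from_sizes layout is what lets chunks from different audios share one flattened batch dimension: input_features is stored as (total_chunks, n_mel, 3000) and chunk_counts records how many of those rows belong to each original clip. A toy illustration of the same bookkeeping, shapes only and with no vLLM imports:

```python
import torch

n_mel = 128                           # placeholder mel-bin count, not from the config
chunk_counts = torch.tensor([1, 2])   # audio 0 -> 1 chunk, audio 1 -> 2 chunks
input_features = torch.zeros(int(chunk_counts.sum()), n_mel, 3000)

# Recover the per-audio groups the same way a flat_from_sizes field is split.
per_audio = torch.split(input_features, chunk_counts.tolist(), dim=0)
assert [t.shape[0] for t in per_audio] == [1, 2]
```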
class AudioFlamingo3MultiModalProcessor(
    BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
):
    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_feature_extractor()
        return AudioFlamingo3MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        # Calculate chunk counts
        audio_list = mm_data.get("audio")
        if not isinstance(audio_list, list):
            audio_list = [audio_list]

        chunk_counts = []
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        window_size = int(sampling_rate * chunk_length)
        # MAX_AUDIO_LEN is 10 * 60 in HF processor.
        max_windows = int(MAX_AUDIO_LEN // chunk_length)

        for audio in audio_list:
            # audio is numpy array or list
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]

            n_win = max(1, (n_samples + window_size - 1) // window_size)
            if n_win > max_windows:
                n_win = max_windows
            chunk_counts.append(n_win)

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if "input_features_mask" in outputs:
            outputs["feature_attention_mask"] = outputs.pop("input_features_mask")

        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _audioflamingo3_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = getattr(processor, "audio_token", "<sound>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Fallback if not found, though it should be there
            audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")

        def get_replacement_audioflamingo3(item_idx: int):
            if feature_attention_mask is not None:
                if chunk_counts is not None:
                    counts = (
                        chunk_counts.tolist()
                        if isinstance(chunk_counts, torch.Tensor)
                        else chunk_counts
                    )
                    start_idx = sum(counts[:item_idx])
                    count = counts[item_idx]
                    end_idx = start_idx + count

                    if isinstance(feature_attention_mask, list):
                        mask_list = feature_attention_mask[start_idx:end_idx]
                        if len(mask_list) > 0 and isinstance(
                            mask_list[0], torch.Tensor
                        ):
                            mask = torch.stack(mask_list)
                        else:
                            mask = torch.tensor(mask_list)
                    else:
                        mask = feature_attention_mask[start_idx:end_idx]
                else:
                    # feature_attention_mask is list[Tensor] or Tensor
                    if isinstance(feature_attention_mask, list):
                        mask = feature_attention_mask[item_idx]
                    else:
                        mask = feature_attention_mask[item_idx].unsqueeze(0)

                # mask shape: (num_chunks, 3000)
                input_lengths = mask.sum(-1)
                conv_lengths = (input_lengths - 1) // 2 + 1
                audio_output_lengths = (conv_lengths - 2) // 2 + 1
                num_features = audio_output_lengths.sum().item()
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                raise ValueError("Audio is too short")

            audio_tokens = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_audioflamingo3,
            )
        ]
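Putting the two earlier calculations together explains what the prompt replacement does: each <sound> placeholder in the prompt is expanded to one audio token id per output embedding, i.e. up to 750 tokens per 30 s chunk, summed over the chunks of that clip. A rough standalone estimate, assuming full 30 s chunks; the real code derives the count from the feature attention mask:

```python
import math

TOKENS_PER_FULL_CHUNK = 750  # 3000 mel frames -> 1500 after conv -> 750 after pooling
MAX_CHUNKS = 20              # 10-minute cap divided by 30 s windows


def sound_token_upper_bound(duration_s: float) -> int:
    # One 750-token block per started 30 s window; the real count comes from the
    # feature attention mask, so a partially filled last window contributes fewer.
    n_chunks = min(max(1, math.ceil(duration_s / 30)), MAX_CHUNKS)
    return n_chunks * TOKENS_PER_FULL_CHUNK


print(sound_token_upper_bound(30))    # 750
print(sound_token_upper_bound(45))    # 1500; the 15 s tail actually expands to fewer
print(sound_token_upper_bound(3600))  # 15000, i.e. capped at 20 chunks
```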
@MULTIMODAL_REGISTRY.register_processor(
    AudioFlamingo3MultiModalProcessor,
    info=AudioFlamingo3ProcessingInfo,
    dummy_inputs=AudioFlamingo3DummyInputsBuilder,
)
class AudioFlamingo3ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
    """
    AudioFlamingo3 model for conditional generation.

    This model integrates a Whisper-based audio encoder with a Qwen2 language model.
    It supports multi-chunk audio processing.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = AudioFlamingo3Encoder(
            config.audio_config,
        )
        self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> AudioFlamingo3Inputs | None:
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
        chunk_counts = kwargs.pop("chunk_counts", None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return AudioFlamingo3EmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return AudioFlamingo3FeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
                chunk_counts=chunk_counts,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: AudioFlamingo3Inputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        chunk_counts = audio_input.get("chunk_counts")

        if isinstance(input_features, list):
            input_features = torch.cat(input_features, dim=0)
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        if chunk_counts is None:
            chunk_counts = [1] * input_features.shape[0]
        elif isinstance(chunk_counts, torch.Tensor):
            chunk_counts = chunk_counts.tolist()
        elif (
            isinstance(chunk_counts, list)
            and chunk_counts
            and isinstance(chunk_counts[0], torch.Tensor)
        ):
            chunk_counts = [c.item() for c in chunk_counts]

        # Calculate output lengths
        input_lengths = feature_attention_mask.sum(-1)
        # Conv downsampling
        conv_lengths = (input_lengths - 1) // 2 + 1
        # AvgPool downsampling
        audio_output_lengths = (conv_lengths - 2) // 2 + 1

        batch_size, _, max_mel_seq_len = input_features.shape

        # Calculate max_seq_len after convs (before pooling) for attention mask
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=conv_lengths.dtype,
                device=conv_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        # Forward pass
        audio_features = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )

        # Project
        audio_features = self.multi_modal_projector(audio_features)

        # Masking after pooling
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        audio_features_mask = (
            torch.arange(max_audio_tokens)
            .expand(num_audios, max_audio_tokens)
            .to(audio_output_lengths.device)
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

        grouped_embeddings = []
        current_idx = 0
        for count in chunk_counts:
            audio_chunks = chunk_embeddings[current_idx : current_idx + count]
            grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
            current_idx += count
        return tuple(grouped_embeddings)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
@@ -264,6 +264,10 @@ _CROSS_ENCODER_MODELS = {
 _MULTIMODAL_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
+    "AudioFlamingo3ForConditionalGeneration": (
+        "audioflamingo3",
+        "AudioFlamingo3ForConditionalGeneration",
+    ),
     "AyaVisionForConditionalGeneration": (
         "aya_vision",
         "AyaVisionForConditionalGeneration",
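With both registry entries in place, the new architecture resolves by name. A quick sanity check on a local build, assuming the public ModelRegistry helper that current vLLM releases expose:

```python
from vllm import ModelRegistry

# Should print True on a vLLM build that contains this commit.
print("AudioFlamingo3ForConditionalGeneration" in ModelRegistry.get_supported_archs())
```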