diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index b13ec92957a1..98b7d76313de 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -895,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
+- * `GraniteSpeechForConditionalGeneration`
+ * Granite Speech
+ * T + A
+ * `ibm-granite/granite-speech-3.3-8b`
+ * ✅︎
+ * ✅︎
+ * ✅︎
- * `H2OVLChatModel`
* H2OVL
* T + IE+
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index e3c75d5cb6a9..bab41c915c32 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.
+# Granite Speech
+def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
+    # NOTE - the settings in this example are somewhat different from what
+    # is optimal for granite speech, and it is generally recommended to use
+    # beam search. Check the model README for suggested settings.
+ # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
+ model_name = "ibm-granite/granite-speech-3.3-8b"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ trust_remote_code=True,
+ max_model_len=2048,
+ max_num_seqs=2,
+ enable_lora=True,
+ max_lora_rank=64,
+ limit_mm_per_prompt={"audio": audio_count},
+ )
+
+    # The model has an audio-specific LoRA directly in its model directory;
+ # it should be enabled whenever you pass audio inputs to the model.
+ speech_lora_path = model_name
+ audio_placeholder = "<|audio|>" * audio_count
+ prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompts,
+ lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+ )
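+# NOTE: a rough sketch of how the returned request data is typically consumed
+# (see main() in this script for the actual wiring; `audio`/`sr` below are
+# assumed to come from librosa or a vLLM AudioAsset):
+#
+#   req_data = run_granite_speech("can you transcribe the speech?", 1)
+#   llm = LLM(**asdict(req_data.engine_args))
+#   outputs = llm.generate(
+#       {"prompt": req_data.prompt,
+#        "multi_modal_data": {"audio": [(audio, sr)]}},
+#       lora_request=req_data.lora_requests[0],
+#   )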
+
+
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
@@ -209,6 +240,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
+ "granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
diff --git a/tests/conftest.py b/tests/conftest.py
index 1d8919470f80..e62b56cb5825 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,6 +21,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, _get_and_verify_dtype
@@ -103,10 +104,25 @@ class _VideoAssets(_VideoAssetsBase):
return [prompts["sample_demo_1"]]
+class _AudioAssetsBase(UserList[AudioAsset]):
+ pass
+
+
+class _AudioAssets(_AudioAssetsBase):
+
+ def __init__(self) -> None:
+ super().__init__([
+ AudioAsset("mary_had_lamb"),
+ AudioAsset("winning_call"),
+ ])
+
+
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
+AUDIO_ASSETS = _AudioAssets()
+"""Singleton instance of :class:`_AudioAssets`."""
@pytest.fixture(scope="function", autouse=True)
@@ -263,6 +279,11 @@ def video_assets() -> _VideoAssets:
return VIDEO_ASSETS
+@pytest.fixture(scope="session")
+def audio_assets() -> _AudioAssets:
+ return AUDIO_ASSETS
+
+
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
@@ -390,10 +411,15 @@ class HfRunner:
processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = video
- if audios is not None and (audio_tuple := audios[i]) is not None:
- audio, sr = audio_tuple
- processor_kwargs["audio"] = audio
- processor_kwargs["sampling_rate"] = sr
+ if audios is not None and (audio_inputs := audios[i]) is not None:
+ # HACK - not all processors take sampling_rate; we should
+ # clean this up in the future.
+ if len(audio_inputs) == 2:
+ audio, sr = audio_inputs
+ processor_kwargs["audio"] = audio
+ processor_kwargs["sampling_rate"] = sr
+ else:
+ processor_kwargs["audio"] = audio_inputs
inputs = self.processor(**processor_kwargs)
if isinstance(inputs, BatchFeature):
diff --git a/tests/models/decoder_only/audio_language/test_granite_speech.py b/tests/models/decoder_only/audio_language/test_granite_speech.py
new file mode 100644
index 000000000000..7c14845ec54d
--- /dev/null
+++ b/tests/models/decoder_only/audio_language/test_granite_speech.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Sequence
+from typing import Optional
+
+import pytest
+from transformers import AutoModelForSpeechSeq2Seq
+
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
+from ...registry import HF_EXAMPLE_MODELS
+from ...utils import check_logprobs_close
+
+HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
+
+
+def vllm_to_hf_output(
+ vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
+) -> tuple[list[int], str, Optional[SampleLogprobs]]:
+ """Sanitize hf output to be comparable with vllm output."""
+ output_ids, output_str, out_logprobs = vllm_output
+
+ hf_output_str = output_str + "<|end_of_text|>"
+
+ return output_ids, hf_output_str, out_logprobs
+
+
+MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
+# The audio LoRA lives directly in the model directory, but
+# currently still needs to be passed to vLLM explicitly.
+audio_lora_path = MODEL_NAME
+models = [MODEL_NAME]
+
+
+def run_test(
+ hf_runner: type[HfRunner],
+ vllm_runner: type[VllmRunner],
+ inputs: Sequence[tuple[list[str], PromptAudioInput]],
+ model: str,
+ *,
+ max_model_len: int,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+ tensor_parallel_size: int,
+ distributed_executor_backend: Optional[str] = None,
+):
+ """Inference result should be the same between hf and vllm.
+
+ All the audio fixtures for the test are from AUDIO_ASSETS.
+ For huggingface runner, we provide the audio as input.
+ For vllm runner, we provide MultiModalDataDict objects
+ and corresponding MultiModalConfig as input.
+    Note that the text input is also adjusted to abide by the vllm contract.
+ The text output is sanitized to be able to compare with hf.
+ """
+ # NOTE: take care of the order. run vLLM first, and then run HF.
+ # vLLM needs a fresh new process without cuda initialization.
+ # if we run HF first, the cuda initialization will be done and it
+ # will hurt multiprocessing backend with fork method (the default method).
+    # max_model_len should be greater than the audio feature size
+ with vllm_runner(
+ model,
+ task="generate",
+ max_model_len=max_model_len,
+ max_num_seqs=1,
+ dtype=dtype,
+ limit_mm_per_prompt={"audio": 1},
+ tensor_parallel_size=tensor_parallel_size,
+ distributed_executor_backend=distributed_executor_backend,
+ enable_lora=True,
+ max_lora_rank=64,
+ enforce_eager=True,
+ ) as vllm_model:
+ lora_request = LoRARequest("audio", 1, audio_lora_path)
+ vllm_outputs_per_case = [
+ vllm_model.generate_greedy_logprobs(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ audios=audios,
+ lora_request=lora_request)
+ for prompts, audios in inputs
+ ]
+
+ with hf_runner(model, dtype=dtype,
+ auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
+
+ hf_processor = hf_model.processor
+ eos_token_id = hf_processor.tokenizer.eos_token_id
+
+ hf_outputs_per_case = [
+ hf_model.generate_greedy_logprobs_limit(prompts,
+ max_tokens,
+ num_logprobs=num_logprobs,
+ audios=[audios],
+ eos_token_id=eos_token_id)
+ for prompts, audios in inputs
+ ]
+
+ for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+ vllm_outputs_per_case):
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=[
+ vllm_to_hf_output(output) for output in vllm_outputs
+ ],
+ name_0="hf",
+ name_1="vllm",
+ )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_model_len", [2048])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
+ dtype: str, max_model_len: int, max_tokens: int,
+ num_logprobs: int) -> None:
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+ model_info.check_available_online(on_fail="skip")
+ model_info.check_transformers_version(on_fail="skip")
+
+ audio, sr = audio_assets[0].audio_and_sample_rate
+    # This model expects a 16kHz sample rate, which our test audio
+    # already uses; if this changes, it may break this test,
+    # so we check it directly.
+ assert sr == 16000
+ run_test(
+ hf_runner,
+ vllm_runner,
+ [
+ ([HF_AUDIO_PROMPT], [audio]),
+ ],
+ model,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ tensor_parallel_size=1,
+ )
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index e9dcba8ec089..1d7de946a3f8 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -11,7 +11,7 @@ from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
-from ....conftest import HfRunner, VllmRunner
+from ....conftest import HfRunner, VllmRunner, _AudioAssets
from ....utils import RemoteOpenAIServer
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
@@ -31,12 +31,6 @@ CHUNKED_PREFILL_KWARGS = {
}
-@pytest.fixture(scope="session")
-def audio_assets():
- from vllm.assets.audio import AudioAsset
- return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
-
-
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
@@ -59,7 +53,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
-def server(request, audio_assets):
+def server(request, audio_assets: _AudioAssets):
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
@@ -230,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
-def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
- max_tokens: int, num_logprobs: int,
+def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
+ dtype: str, max_tokens: int,
+ num_logprobs: int,
vllm_kwargs: dict) -> None:
vllm_prompt = _get_prompt(len(audio_assets),
@@ -250,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
@pytest.mark.asyncio
-async def test_online_serving(client, audio_assets):
+async def test_online_serving(client, audio_assets: _AudioAssets):
"""Exercises online serving with/without chunked prefill enabled."""
messages = [{
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index d56638f051f2..b3c56e18b243 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -254,6 +254,7 @@ def _test_processing_correctness_mistral(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
+ "ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d29c8ce633cf..a08924639b17 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -298,9 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
- hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
+ hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+ "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
+ min_transformers_version="4.52.0"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 4901d8f98dac..fcaa24eec8c8 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -517,7 +517,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
- if model_type == "ultravox":
+ if model_type in ("ultravox", "granite_speech"):
return "<|audio|>"
if model_type == "phi4mm":
return f"<|audio_{current_count}|>"
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 17c857ce096a..eed49e74ac9f 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
+ prefix: str = "",
) -> None:
super().__init__()
@@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
class Blip2QFormerSelfOutput(nn.Module):
- def __init__(self, config: Blip2QFormerConfig) -> None:
+ def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
+ prefix: str = "",
) -> None:
super().__init__()
@@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=is_cross_attention,
+ prefix=f"{prefix}.attention",
)
- self.output = Blip2QFormerSelfOutput(config)
+ self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
def forward(
self,
@@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
class Blip2QFormerIntermediate(nn.Module):
- def __init__(self, config: Blip2QFormerConfig) -> None:
+ def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
class Blip2QFormerOutput(nn.Module):
- def __init__(self, config: Blip2QFormerConfig) -> None:
+ def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
layer_idx: int,
+ prefix: str = "",
) -> None:
super().__init__()
@@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
- cache_config=cache_config)
+ cache_config=cache_config,
+ prefix=f"{prefix}.attention")
self.layer_idx = layer_idx
@@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
config,
quant_config=quant_config,
cache_config=cache_config,
- is_cross_attention=True)
+ is_cross_attention=True,
+ prefix=f"{prefix}.crossattention")
self.has_cross_attention = True
else:
self.has_cross_attention = False
- self.intermediate_query = Blip2QFormerIntermediate(config)
- self.output_query = Blip2QFormerOutput(config)
+ self.intermediate_query = Blip2QFormerIntermediate(
+ config, prefix=f"{prefix}.intermediate_query")
+ self.output_query = Blip2QFormerOutput(config,
+ prefix=f"{prefix}.output_query")
def forward(
self,
@@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
+ prefix: str = "",
) -> None:
super().__init__()
@@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
- layer_idx=layer_idx)
+ layer_idx=layer_idx,
+ prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
@@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
+ prefix: str = "",
) -> None:
super().__init__()
@@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
- cache_config=cache_config)
+ cache_config=cache_config,
+ prefix=f"{prefix}.encoder")
def forward(
self,
@@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.qformer")
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
new file mode 100644
index 000000000000..b43b59da6d11
--- /dev/null
+++ b/vllm/model_executor/models/granite_speech.py
@@ -0,0 +1,777 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2025 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only IBM Granite speeech model."""
+import math
+from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import BatchFeature, PretrainedConfig
+
+from vllm.config import CacheConfig, VllmConfig
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import get_sampler
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs)
+from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+ MultiModalDataParser)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+
+from .blip2 import Blip2QFormerModel
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+ SupportsMultiModal, SupportsPP)
+from .utils import (AutoWeightsLoader, embed_multimodal,
+ init_vllm_registered_model, maybe_prefix)
+
+
+### Audio Input
+class GraniteSpeechAudioInputs(TypedDict):
+
+ input_features: torch.Tensor
+ """Shape: `(bsz, num_features, 160)`"""
+
+ input_features_mask: torch.Tensor
+ """Shape: `(bsz, num_features)`"""
+
+ audio_embed_sizes: list[int]
+ """List of length `bsz`"""
+
+
+class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"audio": 1}
+
+ # There is no limit to the maximum number of audio tokens that can be
+ # encoded as features; we pick ~5000 as a number that is probably higher
+    # than we would expect to encounter. An audio sequence of length
+    # get_max_audio_len() produces get_max_audio_tokens() tokens.
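+    # For reference, get_max_audio_len() is 8,000,000 samples, which at the
+    # 16kHz sampling rate this model expects is roughly 500 seconds of audio.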
+ def get_max_audio_tokens(self):
+ return 5001
+
+ def get_max_audio_len(self):
+ return 8000000
+
+
+### Input Processing & Multimodal utils
+class GraniteSpeechMultiModalProcessor(
+ BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
+
+ def _get_data_parser(self) -> MultiModalDataParser:
+ feature_extractor = self.info.get_hf_processor().audio_processor
+ sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
+ return MultiModalDataParser(target_sr=sampling_rate)
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ input_features=MultiModalFieldConfig.batched("audio"),
+ audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptUpdate]:
+ processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ feature_extractor = processor.audio_processor
+ vocab = tokenizer.get_vocab()
+
+ # Use getattr with default to be compatible with transformers<4.48
+ audio_token = getattr(processor, "audio_token", "<|audio|>")
+ audio_token_id = vocab[audio_token]
+
+ def get_replacement(item_idx: int):
+ audios = mm_items.get_items("audio", AudioProcessorItems)
+ audio = audios.get(item_idx)
+ audio_length = audio.shape[-1]
+ num_projector_features = feature_extractor._get_num_audio_features(
+ [audio_length])[0]
+ return [audio_token_id] * num_projector_features
+
+ return [
+ PromptReplacement(
+ modality="audio",
+ target=[audio_token_id],
+ replacement=get_replacement,
+ )
+ ]
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ mm_data = dict(mm_data)
+ audios = mm_data.pop("audios", [])
+
+ if audios:
+ # GraniteSpeechFeatureExtractor accepts "audio"
+ mm_data["audio"] = audios
+
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ if "audio" in mm_data:
+ # Calculate the number of audio tokens per entry in the batch;
+ # This is used to split the batch back out after padding.
+ audio_token_index = self.info.get_hf_config().audio_token_index
+ processed_outputs["audio_embed_sizes"] = [
+ torch.sum(indices == audio_token_index).item()
+ for indices in processed_outputs["input_ids"]
+ ]
+
+ return processed_outputs
+
+
+class GraniteSpeechDummyInputsBuilder(
+ BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_audios = mm_counts.get("audio", 0)
+ return {
+ "audio":
+ self._get_dummy_audios(
+ length=self.info.get_max_audio_len(),
+ num_audios=num_audios,
+ )
+ }
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_audios = mm_counts.get("audio", 0)
+ hf_processor = self.info.get_hf_processor()
+ audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
+ return audio_token * num_audios
+
+
+### QFormer Projector
+class GraniteSpeechEncoderProjector(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: CacheConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = config.projector_config.hidden_size
+ self.downsample_rate = config.downsample_rate
+ self.window_size = config.window_size
+ self.num_queries = config.window_size // config.downsample_rate
+
+ self.query = nn.Parameter(
+ torch.zeros(1, self.num_queries,
+ config.projector_config.hidden_size))
+
+ # NOTE - this is implemented generically in transformers,
+ # but for now we create the QFormer model directly since
+ # all existing models use this for the projector.
+ self.qformer = Blip2QFormerModel(
+ config.projector_config,
+ quant_config=quant_config,
+ cache_config=cache_config,
+ prefix=f"{prefix}.qformer",
+ )
+ self.linear = nn.Linear(config.projector_config.hidden_size,
+ config.text_config.hidden_size)
+
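+    # Worked example of the windowed projection in forward() (illustrative
+    # values, not necessarily the shipped config): with window_size=15 and
+    # downsample_rate=5, a clip with 100 encoder frames is padded to
+    # 7 blocks of 15 frames, the QFormer emits 15 // 5 = 3 query vectors per
+    # block, and the projection yields 7 * 15 // 5 = 21 audio embeddings.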
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len, dim = hidden_states.size()
+ nblocks = math.ceil(seq_len / self.window_size)
+ pad = nblocks * self.window_size - seq_len
+ hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
+ "constant", 0)
+ hidden_states = hidden_states.view(batch_size * nblocks,
+ self.window_size, dim)
+
+ last_hidden_state = self.qformer(
+ query_embeds=self.query.data,
+ encoder_hidden_states=hidden_states,
+ )
+
+ query_proj = self.linear(
+ last_hidden_state.view(
+ batch_size,
+ nblocks * self.window_size // self.downsample_rate,
+ -1,
+ ))
+ return query_proj
+
+
+# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
+# NOTE - it would be nice to see if we can align this with other models using
+# conformer in vLLM, e.g., phi4mm audio.
+class GraniteSpeechConformerFeedForward(nn.Module):
+ """Feedforward module for conformer encoder blocks."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(config.hidden_dim)
+
+ self.up_proj = ColumnParallelLinear(
+ input_size=config.hidden_dim,
+ output_size=config.hidden_dim * config.feedforward_mult,
+ quant_config=quant_config,
+ prefix=f"{prefix}.up_proj",
+ )
+ self.silu = nn.SiLU()
+
+ self.down_proj = RowParallelLinear(
+ input_size=config.hidden_dim * config.feedforward_mult,
+ output_size=config.hidden_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(hidden_states)
+ hidden_states, _ = self.up_proj(hidden_states)
+ hidden_states = self.silu(hidden_states)
+ hidden_states, _ = self.down_proj(hidden_states)
+ return hidden_states
+
+
+class GraniteSpeechConformerAttention(nn.Module):
+ """Attention for conformer blocks using Shaw's relative positional
+ embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
+ for more details.
+ """
+
+ def __init__(self, config: PretrainedConfig, prefix: str = ""):
+ super().__init__()
+
+ inner_dim = config.dim_head * config.num_heads
+ self.max_pos_emb = config.max_pos_emb
+ self.context_size = config.context_size
+ self.num_heads = config.num_heads
+ self.dim_head = config.dim_head
+ self.scale = self.dim_head**-0.5
+ self.pre_norm = nn.LayerNorm(config.hidden_dim)
+ self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, config.hidden_dim)
+ self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
+ self.dim_head)
+
+ if self.context_size <= 0 or self.context_size > self.max_pos_emb:
+            raise ValueError(
+                "Context size must be positive and cannot exceed max_pos_emb"
+            )
+
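+    # Shape sketch for the block attention in forward() (illustrative values
+    # only): with context_size=200, num_heads=8, and dim_head=64, an input
+    # of shape (bsz, 450, hidden_dim) is right-padded to 3 blocks of 200
+    # frames, q/k/v become (bsz, 3, 8, 200, 64), and the padded tail of the
+    # last block is masked out before attention is applied.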
+ def forward(self, hidden_states: torch.Tensor,
+ attention_dists: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(hidden_states)
+ bsz, num_features, _ = hidden_states.shape
+
+ num_blocks = math.ceil(num_features / self.context_size)
+ remainder = num_features % self.context_size
+ if remainder > 0:
+ # right padding to reach block size
+ hidden_states = torch.nn.functional.pad(
+ hidden_states, (0, 0, 0, self.context_size - remainder))
+
+        # NOTE: it would be nice to try to use QKVParallelLinear
+        # here for this block attention implementation if possible
+ query_states = self.to_q(hidden_states)
+ key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
+
+ query_states = query_states.reshape(bsz, num_blocks, self.context_size,
+ self.num_heads,
+ -1).transpose(2, 3)
+ key_states = key_states.reshape(bsz, num_blocks, self.context_size,
+ self.num_heads, -1).transpose(2, 3)
+ value_states = value_states.reshape(bsz, num_blocks, self.context_size,
+ self.num_heads,
+ -1).transpose(2, 3)
+
+ # shaw's relative positional embedding
+ dist = attention_dists.to(hidden_states.device)
+ rel_pos_emb = self.rel_pos_emb(dist)
+ rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
+ list(rel_pos_emb.shape))
+ pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
+ dim=-1) * self.scale
+
+ if remainder > 0:
+ # masked attention in the extended block
+ mask = torch.ones(self.context_size,
+ self.context_size,
+ dtype=bool,
+ device=hidden_states.device)
+ mask[:remainder, :remainder] = 0
+ mask_value = -torch.finfo(pos_attn.dtype).max
+ pos_attn[:, -1, :].masked_fill_(mask, mask_value)
+
+ with torch.nn.attention.sdpa_kernel(
+ torch.nn.attention.SDPBackend.MATH):
+ out = F.scaled_dot_product_attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=pos_attn,
+ scale=self.scale)
+ out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
+ return self.to_out(out[:, :num_features, :])
+
+
+class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
+ """Wrapper for padded 1D pointwise convolution."""
+
+ def __init__(self,
+ chan_in: int,
+ chan_out: int,
+ kernel_size: int,
+ prefix: str = ""):
+ super().__init__()
+ # Padding for the 1D conv is symmetric or close (i.e., offset by one).
+ pad = kernel_size // 2
+ pad_offset = (kernel_size + 1) % 2
+ self.padding = (pad, pad - pad_offset)
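+        # e.g., kernel_size=15 -> padding (7, 7); kernel_size=4 -> (2, 1);
+        # either way the stride-1 output length matches the input length.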
+
+ self.conv = nn.Conv1d(chan_in,
+ chan_out,
+ kernel_size,
+ groups=chan_in,
+ bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, self.padding)
+ return self.conv(hidden_states)
+
+
+class GraniteSpeechConformerConvModule(nn.Module):
+ """Conformer conv module consisting of several 1D/depthwise 1D
+ convolutional layers.
+ """
+
+ def __init__(self, config: PretrainedConfig, prefix: str = ""):
+ super().__init__()
+ inner_dim = config.hidden_dim * config.conv_expansion_factor
+
+ self.norm = nn.LayerNorm(config.hidden_dim)
+ self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
+ self.glu = nn.GLU(dim=1)
+ self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
+ inner_dim,
+ inner_dim,
+ kernel_size=config.conv_kernel_size,
+ prefix=f"{prefix}.depth_conv",
+ )
+ self.silu = nn.SiLU()
+ self.batch_norm = nn.BatchNorm1d(inner_dim)
+ self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
+ hidden_states = self.glu(hidden_states)
+ hidden_states = self.depth_conv(hidden_states)
+ hidden_states = self.silu(self.batch_norm(hidden_states))
+ hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
+ return hidden_states
+
+
+class GraniteSpeechConformerBlock(nn.Module):
+ """Conformer block, consisting largely of linear layers,
+ attention, and convolutional layers."""
+
+ def __init__(self, config: PretrainedConfig, prefix: str = ""):
+ super().__init__()
+ self.ff1 = GraniteSpeechConformerFeedForward(config,
+ prefix=f"{prefix}.ff1")
+ self.attn = GraniteSpeechConformerAttention(config,
+ prefix=f"{prefix}.attn")
+ self.conv = GraniteSpeechConformerConvModule(config,
+ prefix=f"{prefix}.conv")
+ self.ff2 = GraniteSpeechConformerFeedForward(config,
+ prefix=f"{prefix}.ff2")
+ self.post_norm = nn.LayerNorm(config.hidden_dim)
+
+ def forward(self, hidden_states: torch.Tensor,
+ attention_dists: torch.Tensor) -> torch.Tensor:
+ hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
+ hidden_states = self.attn(
+ hidden_states, attention_dists=attention_dists) + hidden_states
+ hidden_states = self.conv(hidden_states) + hidden_states
+ hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
+ hidden_states = self.post_norm(hidden_states)
+ return hidden_states
+
+
+class GraniteSpeechCTCEncoder(nn.Module):
+ """CTC Encoder comprising conformer blocks and additional linear layers."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ prefix: str,
+ quant_config: Optional[QuantizationConfig] = None):
+ super().__init__()
+ self.config = config
+
+ # Precompute clamped relative positional encoding distances
+ seq = torch.arange(config.context_size)
+ relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+ self.attention_dists = torch.clamp(
+ relpos_dist, -config.context_size,
+ config.context_size) + config.max_pos_emb
+
+ self.input_linear = nn.Linear(config.input_dim,
+ config.hidden_dim,
+ bias=True)
+ self.layers = nn.ModuleList([
+ GraniteSpeechConformerBlock(
+ config,
+ prefix=f"{prefix}.layers.{idx}",
+ ) for idx in range(config.num_layers)
+ ])
+
+ self.out = ColumnParallelLinear(
+ input_size=config.hidden_dim,
+ output_size=config.output_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out",
+ )
+
+ self.out_mid = RowParallelLinear(
+ input_size=config.output_dim,
+ output_size=config.hidden_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_mid",
+ )
+ self.softmax = nn.Softmax(dim=-1)
+ self.num_layers = config.num_layers
+
+ def forward(self, hidden_states: torch.Tensor):
+ hidden_states = self.input_linear(hidden_states)
+ for idx, layer in enumerate(self.layers, start=1):
+ hidden_states = layer(hidden_states,
+ attention_dists=self.attention_dists)
+
+ if idx == self.num_layers // 2:
+ hidden_states_mid = hidden_states.clone()
+ hidden_states_mid, _ = self.out(hidden_states_mid)
+ hidden_states_mid = self.softmax(hidden_states_mid)
+ hidden_states_mid, _ = self.out_mid(hidden_states_mid)
+ hidden_states += hidden_states_mid
+ return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ GraniteSpeechMultiModalProcessor,
+ info=GraniteSpeechMultiModalProcessingInfo,
+ dummy_inputs=GraniteSpeechDummyInputsBuilder)
+class GraniteSpeechForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsLoRA,
+):
+
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ cache_config = vllm_config.cache_config
+
+ self.config = config
+ self.quant_config = quant_config
+ self.cache_config = cache_config
+ self.sampler = get_sampler()
+
+ # The language model is typically a Granite LLM
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+
+ # Conformer encoder
+ self.encoder = GraniteSpeechCTCEncoder(
+ config=config.encoder_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.encoder",
+ )
+
+ # Blip2 QFormer
+ self.projector = GraniteSpeechEncoderProjector(
+ config=config,
+ quant_config=quant_config,
+ cache_config=cache_config,
+ prefix=f"{prefix}.projector",
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def _parse_and_validate_audio_input(
+ self,
+ **kwargs: object,
+ ) -> Optional[GraniteSpeechAudioInputs]:
+ input_features = kwargs.pop("input_features", None)
+ input_features_mask = kwargs.pop("input_features_mask", None)
+ audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
+ if input_features is None:
+ return None
+
+ # If we have a batch of variable feature length audio clips, we need
+ # to mask the features; usually we would get an input_features_mask
+ # from the processor, but we handle rebuilding it here since
+ # vLLM generally processes everything independently + batches.
+ if input_features_mask is None:
+ input_features_mask = self._build_input_features_mask(
+ audio_embed_sizes)
+
+ if not isinstance(input_features, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio input features. "
+ f"Got type: {type(input_features)}")
+
+ if input_features_mask is not None and not isinstance(
+ input_features_mask, torch.Tensor):
+ raise ValueError("Incorrect type of audio input features mask. "
+ f"Got type: {type(input_features_mask)}")
+
+ if isinstance(input_features, torch.Tensor):
+ # Granite speech currently only allows one audio token per instance
+ # and features are already unsqueezed in the processor, so one
+ # instance will have shape [1, {num_features}, 160]. As such,
+ # input features will usually be of shape
+ # [bsz, 1, num_features, 160], which we squeeze to be 3D here.
+ if len(input_features.shape) == 4:
+ input_features = input_features.squeeze(1)
+ if len(input_features.shape) != 3:
+ raise ValueError(
+ "Squeezed input features should be 3D but are of shape "
+ f"{input_features.shape}")
+ input_features = input_features.to(
+ self.encoder.input_linear.weight.dtype)
+
+ else:
+            # Otherwise we have a list of tensors, which almost certainly
+            # differ in their respective numbers of audio features;
+ # stack them into a 3D tensor of size [bsz, most_num_features, 160].
+ input_features = self._pad_and_stack_input_features(
+ input_features, ).to(self.encoder.input_linear.weight.dtype)
+
+ return GraniteSpeechAudioInputs(
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
+ )
+
+ def _build_input_features_mask(
+ self,
+ audio_embed_sizes: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate the input features mask, which will generally be used
+ to mask the the padded features for all entries in the batch except
+ for those with the most audio features.
+
+ Args:
+ audio_embed_sizes: torch.Tensor
+                Tensor of the number of audio features for each sequence
+                in the batch.
+ Returns:
+ torch.Tensor: Mask of shape (bsz, num_features) to be applied to
+ the audio features prior to splitting the audio embeddings.
+ """
+ most_audio_features = torch.max(audio_embed_sizes).item()
+ mask_indices = torch.arange(
+ most_audio_features,
+ device=audio_embed_sizes.device,
+ ).view(1, -1)
+ input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
+ return input_features_mask
+
+ def _pad_and_stack_input_features(
+ self,
+ input_features: list[torch.Tensor],
+ ) -> torch.Tensor:
+ """Given a list of input features of varying length, pad them to the
+ same length and stack them into a torch.Tensor.
+
+        NOTE: Usually, padding is done in the input processor/feature
+        extractor, and the audio is zero padded prior to the computation of
+        the Mel features; the resulting values are only constant within a
+        batch and generally nonzero (i.e., slightly negative numbers). We
+        should validate that this is okay since we don't use a feature
+        attention mask, but the more important thing is that we apply the
+        input_features_mask with variable-length batches.
+
+ Args:
+ input_features: list[torch.Tensor]
+ Input features to be coerced into a tensor.
+ Returns:
+ torch.Tensor: Tensor of shape [bsz, num_features, 160], where
+ num_features is the max number of features of any entry in the
+ batch.
+ """
+        # Each entry in input_features has shape [1, num_features, 160]
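+        # e.g., two clips with features of shape [1, 10, 160] and [1, 7, 160]
+        # are padded and stacked into a single [2, 10, 160] tensor.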
+ feat_lens = [feats.shape[1] for feats in input_features]
+ padding = [max(feat_lens) - length for length in feat_lens]
+ # TODO (Alex) - Validate that it's okay to zero pad like this;
+ # in transformers we zero pad prior to calculating the speech features,
+ # so the value is not zero and is dependent on the batched features.
+ padded = [
+ torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
+ for feats, pad in zip(input_features, padding)
+ ]
+ stacked_features = torch.cat(padded, dim=0).to(input_features[0])
+ return stacked_features
+
+ def _process_audio_input(
+ self,
+ audio_input: GraniteSpeechAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
+ """Compute the audio features to be merged into the LLM embeddings.
+
+ Args:
+ audio_input: GraniteSpeechAudioInputs
+ Audio inputs object containing Mel features, an input features
+ mask, and the (flattened) number of audio tokens per instance.
+ Returns:
+            tuple[torch.Tensor, ...]: Tuple of length bsz.
+ """
+ # TODO (Alex) - support embedding inputs
+ encoder_embeds = self.encoder(audio_input["input_features"])
+        # Projected embeddings:
+        # [bsz, padded_num_audio_tokens, text_hidden_size]
+        projected_embeds = self.projector(encoder_embeds)
+ # Apply mask on variable length audio features
+ masked_embeds = projected_embeds[audio_input["input_features_mask"]]
+ # Split variable length features into a tuple
+ return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
+
+ def get_multimodal_embeddings(
+ self,
+ **kwargs: object,
+ ) -> Optional[MultiModalEmbeddings]:
+ """Compute the audio embeddings if audio inputs are present."""
+ audio_input = self._parse_and_validate_audio_input(**kwargs)
+ if audio_input is None:
+ return None
+ audio_features = self._process_audio_input(audio_input)
+ return audio_features
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ """Compute the merged LLM / audio embeddings."""
+ if multimodal_embeddings is None:
+ return self.language_model.get_input_embeddings(input_ids)
+
+ inputs_embeds = embed_multimodal(
+ input_ids,
+ self.config.audio_token_index,
+ self.language_model.model.get_input_embeddings,
+ multimodal_embeddings,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at the model runner;
+        # this condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ audio_embeds = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
+ input_ids = None
+
+ model_output = self.language_model(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
+ return model_output
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(
+ hidden_states,
+ sampling_metadata,
+ )
+
+ def load_weights(
+ self,
+ weights: Iterable[Tuple[str, torch.Tensor]],
+ ) -> Set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """Get the module prefix in multimodal models."""
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="projector",
+ tower_model="encoder",
+ )
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 11e663e32d45..33877829fdb4 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -178,6 +178,7 @@ _MULTIMODAL_MODELS = {
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
+ "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),