[Model] Add Granite Speech Support (#16246)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
aec9674dbe
commit
fa93cd9f60
@ -895,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generative models.
  * ✅︎
  * ✅︎
  * ✅︎
- * `GraniteSpeechForConditionalGeneration`
  * Granite Speech
  * T + A
  * `ibm-granite/granite-speech-3.3-8b`
  * ✅︎
  * ✅︎
  * ✅︎
- * `H2OVLChatModel`
  * H2OVL
  * T + I<sup>E+</sup>
@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):

# Unless specified, these settings have been tested to work on a single L4.


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    # NOTE - the settings in this example are somewhat different from what
    # is optimal for granite speech, and it is generally recommended to use
    # beam search. Check the model README for suggested settings.
    # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # The model has an audio-specific LoRA directly in its model dir;
    # it should be enabled whenever you pass audio inputs to the model.
    speech_lora_path = model_name
    audio_placeholder = "<|audio|>" * audio_count
    prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompts,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    model_name = "openbmb/MiniCPM-o-2_6"

@ -209,6 +240,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


model_example_map = {
    "granite_speech": run_granite_speech,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
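As a quick, hedged illustration (not part of this diff), the ModelRequestData built by run_granite_speech could be consumed with the example script's usual EngineArgs-to-LLM pattern; audio_data and the sampling settings below are placeholders rather than recommended values:

    from dataclasses import asdict

    from vllm import LLM, SamplingParams

    req = run_granite_speech("can you transcribe the speech?", audio_count=1)
    llm = LLM(**asdict(req.engine_args))
    # audio_data is assumed to be a (waveform, sample_rate) tuple,
    # e.g. loaded via librosa or vllm.assets.audio.AudioAsset.
    outputs = llm.generate(
        {"prompt": req.prompt, "multi_modal_data": {"audio": audio_data}},
        SamplingParams(temperature=0.0, max_tokens=64),
        lora_request=req.lora_requests[0],
    )
    print(outputs[0].outputs[0].text)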
@ -21,6 +21,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass

from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, _get_and_verify_dtype

@ -103,10 +104,25 @@ class _VideoAssets(_VideoAssetsBase):
        return [prompts["sample_demo_1"]]


class _AudioAssetsBase(UserList[AudioAsset]):
    pass


class _AudioAssets(_AudioAssetsBase):

    def __init__(self) -> None:
        super().__init__([
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ])


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
AUDIO_ASSETS = _AudioAssets()
"""Singleton instance of :class:`_AudioAssets`."""


@pytest.fixture(scope="function", autouse=True)

@ -263,6 +279,11 @@ def video_assets() -> _VideoAssets:
    return VIDEO_ASSETS


@pytest.fixture(scope="session")
def audio_assets() -> _AudioAssets:
    return AUDIO_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")

@ -390,10 +411,15 @@ class HfRunner:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_tuple := audios[i]) is not None:
                audio, sr = audio_tuple
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
143 tests/models/decoder_only/audio_language/test_granite_speech.py Normal file
@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoModelForSpeechSeq2Seq
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
|
||||
|
||||
|
||||
def vllm_to_hf_output(
|
||||
vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
|
||||
) -> tuple[list[int], str, Optional[SampleLogprobs]]:
|
||||
"""Sanitize hf output to be comparable with vllm output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
|
||||
hf_output_str = output_str + "<|end_of_text|>"
|
||||
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
|
||||
# The audio LoRA lives directly in the model directory, but it currently
# still needs to be passed to vLLM explicitly.
|
||||
audio_lora_path = MODEL_NAME
|
||||
models = [MODEL_NAME]
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
inputs: Sequence[tuple[list[str], PromptAudioInput]],
|
||||
model: str,
|
||||
*,
|
||||
max_model_len: int,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the audio fixtures for the test are from AUDIO_ASSETS.
|
||||
For huggingface runner, we provide the audio as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
# max_model_len should be greater than the number of audio features
|
||||
with vllm_runner(
|
||||
model,
|
||||
task="generate",
|
||||
max_model_len=max_model_len,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
lora_request = LoRARequest("audio", 1, audio_lora_path)
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=audios,
|
||||
lora_request=lora_request)
|
||||
for prompts, audios in inputs
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
|
||||
|
||||
hf_processor = hf_model.processor
|
||||
eos_token_id = hf_processor.tokenizer.eos_token_id
|
||||
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=[audios],
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, audios in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(output) for output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_model_len", [2048])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
|
||||
dtype: str, max_model_len: int, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
audio, sr = audio_assets[0].audio_and_sample_rate
|
||||
# This model expects a 16kHz sample rate, which our test audio
# already uses; if this changes, it may break this test,
# so we check it directly
|
||||
assert sr == 16000
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
[
|
||||
([HF_AUDIO_PROMPT], [audio]),
|
||||
],
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
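# To run just this file locally, something like
#   pytest tests/models/decoder_only/audio_language/test_granite_speech.py
# should work, assuming a GPU environment where the model weights can be
# downloaded.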
|
||||
@ -11,7 +11,7 @@ from transformers import AutoModel, AutoTokenizer
|
||||
from vllm.multimodal.audio import resample_audio_librosa
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import HfRunner, VllmRunner
|
||||
from ....conftest import HfRunner, VllmRunner, _AudioAssets
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import check_logprobs_close
|
||||
@ -31,12 +31,6 @@ CHUNKED_PREFILL_KWARGS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def audio_assets():
|
||||
from vllm.assets.audio import AudioAsset
|
||||
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
|
||||
def audio(request):
|
||||
from vllm.assets.audio import AudioAsset
|
||||
@ -59,7 +53,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
|
||||
pytest.param({}, marks=pytest.mark.cpu_model),
|
||||
pytest.param(CHUNKED_PREFILL_KWARGS),
|
||||
])
|
||||
def server(request, audio_assets):
|
||||
def server(request, audio_assets: _AudioAssets):
|
||||
args = [
|
||||
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
|
||||
"--limit-mm-per-prompt",
|
||||
@ -230,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
|
||||
pytest.param({}, marks=pytest.mark.cpu_model),
|
||||
pytest.param(CHUNKED_PREFILL_KWARGS),
|
||||
])
|
||||
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
|
||||
max_tokens: int, num_logprobs: int,
|
||||
def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
|
||||
dtype: str, max_tokens: int,
|
||||
num_logprobs: int,
|
||||
vllm_kwargs: dict) -> None:
|
||||
|
||||
vllm_prompt = _get_prompt(len(audio_assets),
|
||||
@ -250,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_online_serving(client, audio_assets):
|
||||
async def test_online_serving(client, audio_assets: _AudioAssets):
|
||||
"""Exercises online serving with/without chunked prefill enabled."""
|
||||
|
||||
messages = [{
|
||||
|
||||
@ -254,6 +254,7 @@ def _test_processing_correctness_mistral(
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
"THUDM/glm-4v-9b",
|
||||
"ibm-granite/granite-speech-3.3-8b",
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
|
||||
@ -298,9 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
transformers_version_reason="HF model is not compatible.", # noqa: E501
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
|
||||
min_transformers_version="4.52.0"), # noqa: E501
|
||||
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
|
||||
@ -517,7 +517,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
raise TypeError(f"Unknown {modality} model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type == "ultravox":
|
||||
if model_type in ("ultravox", "granite_speech"):
|
||||
return "<|audio|>"
|
||||
if model_type == "phi4mm":
|
||||
return f"<|audio_{current_count}|>"
|
||||
|
||||
@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
|
||||
class Blip2QFormerSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=is_cross_attention,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
|
||||
self.output = Blip2QFormerSelfOutput(config)
|
||||
self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
|
||||
|
||||
class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
class Blip2QFormerOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
layer_idx: int,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=True)
|
||||
is_cross_attention=True,
|
||||
prefix=f"{prefix}.crossattention")
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
self.intermediate_query = Blip2QFormerIntermediate(
|
||||
config, prefix=f"{prefix}.intermediate_query")
|
||||
self.output_query = Blip2QFormerOutput(config,
|
||||
prefix=f"{prefix}.output_query")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
Blip2QFormerLayer(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
layer_idx=layer_idx)
|
||||
layer_idx=layer_idx,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qformer")
|
||||
|
||||
self.language_projection = nn.Linear(
|
||||
config.qformer_config.hidden_size,
|
||||
|
||||
777 vllm/model_executor/models/granite_speech.py Normal file
@ -0,0 +1,777 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM Granite speeech model."""
|
||||
import math
|
||||
from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs)
|
||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .blip2 import Blip2QFormerModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, embed_multimodal,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
### Audio Input
|
||||
class GraniteSpeechAudioInputs(TypedDict):
|
||||
|
||||
input_features: torch.Tensor
|
||||
"""Shape: `(bsz, num_features, 160)`"""
|
||||
|
||||
input_features_mask: torch.Tensor
|
||||
"""Shape: `(bsz, num_features)`"""
|
||||
|
||||
audio_embed_sizes: list[int]
|
||||
"""List of length `bsz`"""
|
||||
|
||||
|
||||
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": 1}
|
||||
|
||||
# There is no limit to the maximum number of audio tokens that can be
|
||||
# encoded as features; we pick ~5000 as a number that is probably higher
|
||||
# than we would expect to encounter. The sequence of length
|
||||
# get_max_audio_len() produces get_max_audio_tokens().
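# (For reference, get_max_audio_len() is 8,000,000 samples, which at the
# model's expected 16 kHz input rate is roughly 500 seconds of audio.)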
|
||||
def get_max_audio_tokens(self):
|
||||
return 5001
|
||||
|
||||
def get_max_audio_len(self):
|
||||
return 8000000
|
||||
|
||||
|
||||
### Input Processing & Multimodal utils
|
||||
class GraniteSpeechMultiModalProcessor(
|
||||
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_hf_processor().audio_processor
|
||||
sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
|
||||
return MultiModalDataParser(target_sr=sampling_rate)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
feature_extractor = processor.audio_processor
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
# Use getattr with default to be compatible with transformers<4.48
|
||||
audio_token = getattr(processor, "audio_token", "<|audio|>")
|
||||
audio_token_id = vocab[audio_token]
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio = audios.get(item_idx)
|
||||
audio_length = audio.shape[-1]
|
||||
num_projector_features = feature_extractor._get_num_audio_features(
|
||||
[audio_length])[0]
|
||||
return [audio_token_id] * num_projector_features
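# i.e. each <|audio|> placeholder in the prompt is expanded into one audio
# token per projected audio feature, so downstream token accounting sees the
# true (post-projection) sequence length.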
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=[audio_token_id],
|
||||
replacement=get_replacement,
|
||||
)
|
||||
]
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
|
||||
if audios:
|
||||
# GraniteSpeechFeatureExtractor accepts "audio"
|
||||
mm_data["audio"] = audios
|
||||
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
if "audio" in mm_data:
|
||||
# Calculate the number of audio tokens per entry in the batch;
|
||||
# This is used to split the batch back out after padding.
|
||||
audio_token_index = self.info.get_hf_config().audio_token_index
|
||||
processed_outputs["audio_embed_sizes"] = [
|
||||
torch.sum(indices == audio_token_index).item()
|
||||
for indices in processed_outputs["input_ids"]
|
||||
]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
|
||||
class GraniteSpeechDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
return {
|
||||
"audio":
|
||||
self._get_dummy_audios(
|
||||
length=self.info.get_max_audio_len(),
|
||||
num_audios=num_audios,
|
||||
)
|
||||
}
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
|
||||
return audio_token * num_audios
|
||||
|
||||
|
||||
### QFormer Projector
|
||||
class GraniteSpeechEncoderProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: CacheConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.projector_config.hidden_size
|
||||
self.downsample_rate = config.downsample_rate
|
||||
self.window_size = config.window_size
|
||||
self.num_queries = config.window_size // config.downsample_rate
|
||||
|
||||
self.query = nn.Parameter(
|
||||
torch.zeros(1, self.num_queries,
|
||||
config.projector_config.hidden_size))
|
||||
|
||||
# NOTE - this is implemented generically in transformers,
|
||||
# but for now we create the QFormer model directly since
|
||||
# all existing models use this for the projector.
|
||||
self.qformer = Blip2QFormerModel(
|
||||
config.projector_config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.qformer",
|
||||
)
|
||||
self.linear = nn.Linear(config.projector_config.hidden_size,
|
||||
config.text_config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, dim = hidden_states.size()
|
||||
nblocks = math.ceil(seq_len / self.window_size)
|
||||
pad = nblocks * self.window_size - seq_len
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
|
||||
"constant", 0)
|
||||
hidden_states = hidden_states.view(batch_size * nblocks,
|
||||
self.window_size, dim)
|
||||
|
||||
last_hidden_state = self.qformer(
|
||||
query_embeds=self.query.data,
|
||||
encoder_hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
query_proj = self.linear(
|
||||
last_hidden_state.view(
|
||||
batch_size,
|
||||
nblocks * self.window_size // self.downsample_rate,
|
||||
-1,
|
||||
))
|
||||
return query_proj
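# Worked example (illustrative values, not taken from a real config): with
# window_size=15 and downsample_rate=5, num_queries is 15 // 5 = 3; a
# sequence of length 100 becomes nblocks = 7 blocks padded to 105 frames,
# the qformer emits 3 query vectors per block, and the projector returns
# 7 * 15 // 5 = 21 projected embeddings per batch item.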
|
||||
|
||||
|
||||
# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
|
||||
# NOTE - it would be nice to see if we can align this with other models using
|
||||
# conformer in vLLM, e.g., phi4mm audio.
|
||||
class GraniteSpeechConformerFeedForward(nn.Module):
|
||||
"""Feedforward module for conformer encoder blocks."""
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.pre_norm = nn.LayerNorm(config.hidden_dim)
|
||||
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
input_size=config.hidden_dim,
|
||||
output_size=config.hidden_dim * config.feedforward_mult,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj",
|
||||
)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=config.hidden_dim * config.feedforward_mult,
|
||||
output_size=config.hidden_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states)
|
||||
hidden_states, _ = self.up_proj(hidden_states)
|
||||
hidden_states = self.silu(hidden_states)
|
||||
hidden_states, _ = self.down_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GraniteSpeechConformerAttention(nn.Module):
|
||||
"""Attention for conformer blocks using Shaw's relative positional
|
||||
embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = config.dim_head * config.num_heads
|
||||
self.max_pos_emb = config.max_pos_emb
|
||||
self.context_size = config.context_size
|
||||
self.num_heads = config.num_heads
|
||||
self.dim_head = config.dim_head
|
||||
self.scale = self.dim_head**-0.5
|
||||
self.pre_norm = nn.LayerNorm(config.hidden_dim)
|
||||
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
|
||||
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
|
||||
self.dim_head)
|
||||
|
||||
if self.context_size <= 0 or self.context_size > self.max_pos_emb:
|
||||
raise ValueError(
|
||||
"Context size is either less than 0 or exceeds the max_pos_emb"
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attention_dists: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states)
|
||||
bsz, num_features, _ = hidden_states.shape
|
||||
|
||||
num_blocks = math.ceil(num_features / self.context_size)
|
||||
remainder = num_features % self.context_size
|
||||
if remainder > 0:
|
||||
# right padding to reach block size
|
||||
hidden_states = torch.nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, self.context_size - remainder))
|
||||
|
||||
# NOTE: would be nice to try to use qkvparallellinear
|
||||
# here for this block attention implementation if possible
|
||||
query_states = self.to_q(hidden_states)
|
||||
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
|
||||
|
||||
query_states = query_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads,
|
||||
-1).transpose(2, 3)
|
||||
key_states = key_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads, -1).transpose(2, 3)
|
||||
value_states = value_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads,
|
||||
-1).transpose(2, 3)
|
||||
|
||||
# shaw's relative positional embedding
|
||||
dist = attention_dists.to(hidden_states.device)
|
||||
rel_pos_emb = self.rel_pos_emb(dist)
|
||||
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
|
||||
list(rel_pos_emb.shape))
|
||||
pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
|
||||
dim=-1) * self.scale
|
||||
|
||||
if remainder > 0:
|
||||
# masked attention in the extended block
|
||||
mask = torch.ones(self.context_size,
|
||||
self.context_size,
|
||||
dtype=bool,
|
||||
device=hidden_states.device)
|
||||
mask[:remainder, :remainder] = 0
|
||||
mask_value = -torch.finfo(pos_attn.dtype).max
|
||||
pos_attn[:, -1, :].masked_fill_(mask, mask_value)
|
||||
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
torch.nn.attention.SDPBackend.MATH):
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=pos_attn,
|
||||
scale=self.scale)
|
||||
out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
|
||||
return self.to_out(out[:, :num_features, :])
|
||||
|
||||
|
||||
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
|
||||
"""Wrapper for padded 1D pointwise convolution."""
|
||||
|
||||
def __init__(self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: int,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
# Padding for the 1D conv is symmetric or close (i.e., offset by one).
|
||||
pad = kernel_size // 2
|
||||
pad_offset = (kernel_size + 1) % 2
|
||||
self.padding = (pad, pad - pad_offset)
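# e.g. (illustrative values): kernel_size=15 gives pad=7, pad_offset=0 and
# symmetric padding (7, 7); an even kernel_size=4 gives pad=2, pad_offset=1
# and padding (2, 1). Either way the total padding is kernel_size - 1, so
# the convolution below preserves the sequence length.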
|
||||
|
||||
self.conv = nn.Conv1d(chan_in,
|
||||
chan_out,
|
||||
kernel_size,
|
||||
groups=chan_in,
|
||||
bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.padding)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class GraniteSpeechConformerConvModule(nn.Module):
|
||||
"""Conformer conv module consisting of several 1D/depthwise 1D
|
||||
convolutional layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
inner_dim = config.hidden_dim * config.conv_expansion_factor
|
||||
|
||||
self.norm = nn.LayerNorm(config.hidden_dim)
|
||||
self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
|
||||
self.glu = nn.GLU(dim=1)
|
||||
self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
|
||||
inner_dim,
|
||||
inner_dim,
|
||||
kernel_size=config.conv_kernel_size,
|
||||
prefix=f"{prefix}.depth_conv",
|
||||
)
|
||||
self.silu = nn.SiLU()
|
||||
self.batch_norm = nn.BatchNorm1d(inner_dim)
|
||||
self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
|
||||
hidden_states = self.glu(hidden_states)
|
||||
hidden_states = self.depth_conv(hidden_states)
|
||||
hidden_states = self.silu(self.batch_norm(hidden_states))
|
||||
hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GraniteSpeechConformerBlock(nn.Module):
|
||||
"""Conformer block, consisting largely of linear layers,
|
||||
attention, and convolutional layers."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.ff1 = GraniteSpeechConformerFeedForward(config,
|
||||
prefix=f"{prefix}.ff1")
|
||||
self.attn = GraniteSpeechConformerAttention(config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.conv = GraniteSpeechConformerConvModule(config,
|
||||
prefix=f"{prefix}.conv")
|
||||
self.ff2 = GraniteSpeechConformerFeedForward(config,
|
||||
prefix=f"{prefix}.ff2")
|
||||
self.post_norm = nn.LayerNorm(config.hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attention_dists: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
|
||||
hidden_states = self.attn(
|
||||
hidden_states, attention_dists=attention_dists) + hidden_states
|
||||
hidden_states = self.conv(hidden_states) + hidden_states
|
||||
hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
|
||||
hidden_states = self.post_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GraniteSpeechCTCEncoder(nn.Module):
|
||||
"""CTC Encoder comprising conformer blocks and additional linear layers."""
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Precompute clamped relative positional encoding distances
|
||||
seq = torch.arange(config.context_size)
|
||||
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
|
||||
self.attention_dists = torch.clamp(
|
||||
relpos_dist, -config.context_size,
|
||||
config.context_size) + config.max_pos_emb
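# e.g. (illustrative values): with context_size=3 and max_pos_emb=512,
# relpos_dist is [[0, -1, -2], [1, 0, -1], [2, 1, 0]]; clamping is a no-op
# at this size, and adding max_pos_emb shifts it to
# [[512, 511, 510], [513, 512, 511], [514, 513, 512]], i.e. non-negative
# indices into the (2 * max_pos_emb + 1)-entry rel_pos_emb table.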
|
||||
|
||||
self.input_linear = nn.Linear(config.input_dim,
|
||||
config.hidden_dim,
|
||||
bias=True)
|
||||
self.layers = nn.ModuleList([
|
||||
GraniteSpeechConformerBlock(
|
||||
config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
) for idx in range(config.num_layers)
|
||||
])
|
||||
|
||||
self.out = ColumnParallelLinear(
|
||||
input_size=config.hidden_dim,
|
||||
output_size=config.output_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out",
|
||||
)
|
||||
|
||||
self.out_mid = RowParallelLinear(
|
||||
input_size=config.output_dim,
|
||||
output_size=config.hidden_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_mid",
|
||||
)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = self.input_linear(hidden_states)
|
||||
for idx, layer in enumerate(self.layers, start=1):
|
||||
hidden_states = layer(hidden_states,
|
||||
attention_dists=self.attention_dists)
|
||||
|
||||
if idx == self.num_layers // 2:
|
||||
hidden_states_mid = hidden_states.clone()
|
||||
hidden_states_mid, _ = self.out(hidden_states_mid)
|
||||
hidden_states_mid = self.softmax(hidden_states_mid)
|
||||
hidden_states_mid, _ = self.out_mid(hidden_states_mid)
|
||||
hidden_states += hidden_states_mid
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
GraniteSpeechMultiModalProcessor,
|
||||
info=GraniteSpeechMultiModalProcessingInfo,
|
||||
dummy_inputs=GraniteSpeechDummyInputsBuilder)
|
||||
class GraniteSpeechForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.cache_config = cache_config
|
||||
self.sampler = get_sampler()
|
||||
|
||||
# The language model is typically a Granite LLM
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
# Conformer encoder
|
||||
self.encoder = GraniteSpeechCTCEncoder(
|
||||
config=config.encoder_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
# Blip2 QFormer
|
||||
self.projector = GraniteSpeechEncoderProjector(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.projector",
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> Optional[GraniteSpeechAudioInputs]:
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
input_features_mask = kwargs.pop("input_features_mask", None)
|
||||
audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
|
||||
if input_features is None:
|
||||
return None
|
||||
|
||||
# If we have a batch of variable feature length audio clips, we need
|
||||
# to mask the features; usually we would get an input_features_mask
|
||||
# from the processor, but we handle rebuilding it here since
|
||||
# vLLM generally processes everything independently + batches.
|
||||
if input_features_mask is None:
|
||||
input_features_mask = self._build_input_features_mask(
|
||||
audio_embed_sizes)
|
||||
|
||||
if not isinstance(input_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_features)}")
|
||||
|
||||
if input_features_mask is not None and not isinstance(
|
||||
input_features_mask, torch.Tensor):
|
||||
raise ValueError("Incorrect type of audio input features mask. "
|
||||
f"Got type: {type(input_features_mask)}")
|
||||
|
||||
if isinstance(input_features, torch.Tensor):
|
||||
# Granite speech currently only allows one audio token per instance
|
||||
# and features are already unsqueezed in the processor, so one
|
||||
# instance will have shape [1, {num_features}, 160]. As such,
|
||||
# input features will usually be of shape
|
||||
# [bsz, 1, num_features, 160], which we squeeze to be 3D here.
|
||||
if len(input_features.shape) == 4:
|
||||
input_features = input_features.squeeze(1)
|
||||
if len(input_features.shape) != 3:
|
||||
raise ValueError(
|
||||
"Squeezed input features should be 3D but are of shape "
|
||||
f"{input_features.shape}")
|
||||
input_features = input_features.to(
|
||||
self.encoder.input_linear.weight.dtype)
|
||||
|
||||
else:
|
||||
# Otherwise we have a list of tensors, which are almost certainly
|
||||
# differing in their respective numbers of audio features;
|
||||
# stack them into a 3D tensor of size [bsz, most_num_features, 160].
|
||||
input_features = self._pad_and_stack_input_features(
|
||||
input_features, ).to(self.encoder.input_linear.weight.dtype)
|
||||
|
||||
return GraniteSpeechAudioInputs(
|
||||
input_features=input_features,
|
||||
input_features_mask=input_features_mask,
|
||||
audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
|
||||
)
|
||||
|
||||
def _build_input_features_mask(
|
||||
self,
|
||||
audio_embed_sizes: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate the input features mask, which will generally be used
|
||||
to mask the padded features for all entries in the batch except
|
||||
for those with the most audio features.
|
||||
|
||||
Args:
|
||||
audio_embed_sizes: torch.Tensor
|
||||
Tensor of num features in each seq in the batch.
|
||||
Returns:
|
||||
torch.Tensor: Mask of shape (bsz, num_features) to be applied to
|
||||
the audio features prior to splitting the audio embeddings.
|
||||
"""
|
||||
most_audio_features = torch.max(audio_embed_sizes).item()
|
||||
mask_indices = torch.arange(
|
||||
most_audio_features,
|
||||
device=audio_embed_sizes.device,
|
||||
).view(1, -1)
|
||||
input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
|
||||
return input_features_mask
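# e.g. audio_embed_sizes = tensor([3, 5]) gives most_audio_features = 5,
# mask_indices = [[0, 1, 2, 3, 4]], and the broadcast comparison yields
# [[True, True, True, False, False], [True, True, True, True, True]],
# masking out the padded positions of the shorter entry.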
|
||||
|
||||
def _pad_and_stack_input_features(
|
||||
self,
|
||||
input_features: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Given a list of input features of varying length, pad them to the
|
||||
same length and stack them into a torch.Tensor.
|
||||
|
||||
NOTE: Usually, padding is done in the input processor/feature extractor
|
||||
and zero padded prior to the computation of the Mel features; the
|
||||
resulting values are only constant within a batch and generally nonzero
|
||||
(i.e., slightly negative nums); we should validate that this is okay
|
||||
since we don't use a feature attention mask, but the more important
|
||||
thing is that we apply the input_features_mask with variable len
|
||||
batches.
|
||||
|
||||
Args:
|
||||
input_features: list[torch.Tensor]
|
||||
Input features to be coerced into a tensor.
|
||||
Returns:
|
||||
torch.Tensor: Tensor of shape [bsz, num_features, 160], where
|
||||
num_features is the max number of features of any entry in the
|
||||
batch.
|
||||
"""
|
||||
# Input features are of shape [bsz, num_features, 160]
|
||||
feat_lens = [feats.shape[1] for feats in input_features]
|
||||
padding = [max(feat_lens) - length for length in feat_lens]
|
||||
# TODO (Alex) - Validate that it's okay to zero pad like this;
|
||||
# in transformers we zero pad prior to calculating the speech features,
|
||||
# so the value is not zero and is dependent on the batched features.
|
||||
padded = [
|
||||
torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
|
||||
for feats, pad in zip(input_features, padding)
|
||||
]
|
||||
stacked_features = torch.cat(padded, dim=0).to(input_features[0])
|
||||
return stacked_features
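# e.g. two clips with feature shapes [1, 80, 160] and [1, 120, 160] are
# right-padded along the feature dim to [1, 120, 160] each and concatenated
# into a single [2, 120, 160] tensor; the zero-padded rows are masked out
# after projection via input_features_mask.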
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: GraniteSpeechAudioInputs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
"""Compute the audio features to be merged into the LLM embeddings.
|
||||
|
||||
Args:
|
||||
audio_input: GraniteSpeechAudioInputs
|
||||
Audio inputs object containing Mel features, an input features
|
||||
mask, and the (flattened) number of audio tokens per instance.
|
||||
Returns:
|
||||
tuple[torch.Tensor]: List of length bsz.
|
||||
"""
|
||||
# TODO (Alex) - support embedding inputs
|
||||
encoder_embeds = self.encoder(audio_input["input_features"])
|
||||
# [bsz, <max feature size>, 4096]
|
||||
projected_embeds = self.projector(encoder_embeds)
|
||||
# Apply mask on variable length audio features
|
||||
masked_embeds = projected_embeds[audio_input["input_features_mask"]]
|
||||
# Split variable length features into a tuple
|
||||
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> Optional[MultiModalEmbeddings]:
|
||||
"""Compute the audio embeddings if audio inputs are present."""
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
if audio_input is None:
|
||||
return None
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
return audio_features
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the merged LLM / audio embeddings."""
|
||||
if multimodal_embeddings is None:
|
||||
return self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = embed_multimodal(
|
||||
input_ids,
|
||||
self.config.audio_token_index,
|
||||
self.language_model.model.get_input_embeddings,
|
||||
multimodal_embeddings,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
audio_embeds = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
|
||||
input_ids = None
|
||||
|
||||
model_output = self.language_model(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(
|
||||
hidden_states,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Get the module prefix in multimodal models."""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="projector",
|
||||
tower_model="encoder",
|
||||
)
|
||||
@ -178,6 +178,7 @@ _MULTIMODAL_MODELS = {
|
||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
|
||||
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||
|
||||