# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite speech model."""

import math
from collections.abc import Iterable, Mapping
from typing import Annotated

import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig

from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .blip2 import Blip2QFormerModel
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix


### Audio Input
class GraniteSpeechAudioInputs(TensorSchema):
    """
    Audio input features for Granite Speech model.

    Dimensions:
        - b: Batch size
        - fi: Number of input features from the Mel spectrogram.
        - fo: Number of output features, i.e. the number of projected audio
          embeddings per item.
        - 160: Fixed feature dimension for Mel spectrogram features.
    """

    input_features: Annotated[torch.Tensor, TensorShape("b", "fi", 160)]
    """Audio input features."""

    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "fo")]
    """Mask for variable length audio features."""

    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
    """List of audio embedding sizes for each item in batch."""


class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}

    # There is no hard limit on the number of audio tokens that can be
    # encoded as features; we pick ~5000 as a value that is probably higher
    # than we would expect to encounter. An audio sequence of
    # get_max_audio_len() samples produces get_max_audio_tokens() audio tokens.
    def get_max_audio_tokens(self):
        return 5001

    def get_max_audio_len(self):
        return 8000000
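    # Illustrative relationship between the two limits above (a sketch, not
    # executed here): the audio token count for a waveform is whatever the HF
    # feature extractor reports, e.g.
    #
    #     feature_extractor = self.get_hf_processor().audio_processor
    #     num_tokens = feature_extractor._get_num_audio_features(
    #         [self.get_max_audio_len()]
    #     )[0]
    #
    # and get_max_audio_tokens() is simply that value (5001) for a clip of
    # get_max_audio_len() samples.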


### Input Processing & Multimodal utils
class GraniteSpeechMultiModalProcessor(
    BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]
):
    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_hf_processor().audio_processor
        sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
        return MultiModalDataParser(target_sr=sampling_rate)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        feature_extractor = processor.audio_processor
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|audio|>")
        audio_token_id = vocab[audio_token]

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            audio_length = audio.shape[-1]
            num_projector_features = feature_extractor._get_num_audio_features(
                [audio_length]
            )[0]
            return [audio_token_id] * num_projector_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement,
            )
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        if audios:
            # GraniteSpeechFeatureExtractor accepts "audio"
            mm_data["audio"] = audios

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if "audio" in mm_data:
            # Calculate the number of audio tokens per entry in the batch;
            # this is used to split the batch back out after padding.
            audio_token_index = self.info.get_hf_config().audio_token_index
            processed_outputs["audio_embed_sizes"] = (
                processed_outputs["input_ids"] == audio_token_index
            ).sum(-1)

        return processed_outputs


class GraniteSpeechDummyInputsBuilder(
    BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]
):
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=self.info.get_max_audio_len(),
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
        return audio_token * num_audios


### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        self.query = nn.Parameter(
            torch.zeros(1, self.num_queries, config.projector_config.hidden_size)
        )

        # NOTE - this is implemented generically in transformers,
        # but for now we create the QFormer model directly since
        # all existing models use this for the projector.
        self.qformer = Blip2QFormerModel(
            config.projector_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )
        self.linear = nn.Linear(
            config.projector_config.hidden_size, config.text_config.hidden_size
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.size()
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
        hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)

        last_hidden_state = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=hidden_states,
        )

        query_proj = self.linear(
            last_hidden_state.view(
                batch_size,
                nblocks * self.window_size // self.downsample_rate,
                -1,
            )
        )
        return query_proj
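    # Rough, symbolic shape trace for forward() above:
    #   hidden_states:     [bsz, seq_len, dim]
    #   -> padded/blocked: [bsz * nblocks, window_size, dim]
    #   -> qformer:        [bsz * nblocks, num_queries, projector_hidden]
    #   -> reshaped:       [bsz, nblocks * num_queries, projector_hidden]
    #   -> self.linear:    [bsz, nblocks * num_queries, text_hidden_size]
    # i.e., every window of window_size encoder frames is compressed into
    # num_queries = window_size // downsample_rate query embeddings.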


# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
# NOTE - it would be nice to see if we can align this with other models using
# conformer in vLLM, e.g., phi4mm audio.
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.hidden_dim)

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.hidden_dim * config.feedforward_mult,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()

        self.down_proj = RowParallelLinear(
            input_size=config.hidden_dim * config.feedforward_mult,
            output_size=config.hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        hidden_states, _ = self.up_proj(hidden_states)
        hidden_states = self.silu(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional
    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
    for more details.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError(
                "Context size must be positive and cannot exceed max_pos_emb"
            )

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder)
            )

        # NOTE: would be nice to try to use QKVParallelLinear
        # here for this block attention implementation if possible
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

        query_states = query_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        key_states = key_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        value_states = value_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)

        # Shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
        pos_attn = (
            torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
            * self.scale
        )

        if remainder > 0:
            # masked attention in the extended block
            mask = torch.ones(
                self.context_size,
                self.context_size,
                dtype=bool,
                device=hidden_states.device,
            )
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=pos_attn,
                scale=self.scale,
            )
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        return self.to_out(out[:, :num_features, :])
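    # Rough, symbolic shape trace for the block attention above:
    #   hidden_states is right-padded to a multiple of context_size, then
    #   q/k/v:    [bsz, num_blocks, num_heads, context_size, dim_head]
    #   pos_attn: [bsz, num_blocks, num_heads, context_size, context_size]
    #   sdpa out -> reshaped back to [bsz, padded_len, num_heads * dim_head]
    # and the final slice [:, :num_features, :] drops the right padding again.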


class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for padded 1D depthwise convolution."""

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""):
        super().__init__()
        # Padding for the 1D conv is symmetric, or off by one for even
        # kernel sizes.
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)
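    # Concrete padding examples, worked from the formula above (not tied to
    # any specific checkpoint): kernel_size=15 -> padding=(7, 7);
    # kernel_size=4 -> padding=(2, 1). Either way the conv output length
    # matches the input length.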


class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module consisting of several 1D/depthwise 1D
    convolutional layers.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
        hidden_states = self.glu(hidden_states)
        hidden_states = self.depth_conv(hidden_states)
        hidden_states = self.silu(self.batch_norm(hidden_states))
        hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
        return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block, consisting largely of linear layers,
    attention, and convolutional layers."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
        hidden_states = (
            self.attn(hidden_states, attention_dists=attention_dists) + hidden_states
        )
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
        hidden_states = self.post_norm(hidden_states)
        return hidden_states


class GraniteSpeechCTCEncoder(nn.Module):
    """CTC Encoder comprising conformer blocks and additional linear layers."""

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.config = config

        # Precompute clamped relative positional encoding distances
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        self.attention_dists = (
            torch.clamp(relpos_dist, -config.context_size, config.context_size)
            + config.max_pos_emb
        )

        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            [
                GraniteSpeechConformerBlock(
                    config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(config.num_layers)
            ]
        )

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )

        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, attention_dists=self.attention_dists)

            if idx == self.num_layers // 2:
                hidden_states_mid = hidden_states.clone()
                hidden_states_mid, _ = self.out(hidden_states_mid)
                hidden_states_mid = self.softmax(hidden_states_mid)
                hidden_states_mid, _ = self.out_mid(hidden_states_mid)
                hidden_states += hidden_states_mid
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
    dummy_inputs=GraniteSpeechDummyInputsBuilder,
)
class GraniteSpeechForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
):
    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.quant_config = quant_config
        self.cache_config = cache_config

        # The language model is typically a Granite LLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Conformer encoder
        self.encoder = GraniteSpeechCTCEncoder(
            config=config.encoder_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )

        # Blip2 QFormer
        self.projector = GraniteSpeechEncoderProjector(
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.projector",
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> GraniteSpeechAudioInputs | None:
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)

        if input_features is None:
            return None

        # If we have a batch of variable feature length audio clips, we need
        # to mask the features; usually we would get an input_features_mask
        # from the processor, but we handle rebuilding it here since
        # vLLM generally processes everything independently and then batches.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(audio_embed_sizes)

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of audio input features. "
                f"Got type: {type(input_features)}"
            )

        if input_features_mask is not None and not isinstance(
            input_features_mask, torch.Tensor
        ):
            raise ValueError(
                "Incorrect type of audio input features mask. "
                f"Got type: {type(input_features_mask)}"
            )

        if isinstance(input_features, torch.Tensor):
            # Granite speech currently only allows one audio token per instance
            # and features are already unsqueezed in the processor, so one
            # instance will have shape [1, {num_features}, 160]. As such,
            # input features will usually be of shape
            # [bsz, 1, num_features, 160], which we squeeze to be 3D here.
            if len(input_features.shape) == 4:
                input_features = input_features.squeeze(1)
            if len(input_features.shape) != 3:
                raise ValueError(
                    "Squeezed input features should be 3D but are of shape "
                    f"{input_features.shape}"
                )
            input_features = input_features.to(self.encoder.input_linear.weight.dtype)

        else:
            # Otherwise we have a list of tensors, which are almost certainly
            # differing in their respective numbers of audio features;
            # stack them into a 3D tensor of size [bsz, most_num_features, 160].
            input_features = self._pad_and_stack_input_features(
                input_features,
            ).to(self.encoder.input_linear.weight.dtype)

        return GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
        )

    def _build_input_features_mask(
        self,
        audio_embed_sizes: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the input features mask, which will generally be used
        to mask the padded features for all entries in the batch except
        for those with the most audio features.

        Args:
            audio_embed_sizes: torch.Tensor
                Tensor holding the number of audio features for each
                sequence in the batch.
        Returns:
            torch.Tensor: Mask of shape (bsz, num_features) to be applied to
            the audio features prior to splitting the audio embeddings.
        """
        most_audio_features = torch.max(audio_embed_sizes).item()
        mask_indices = torch.arange(
            most_audio_features,
            device=audio_embed_sizes.device,
        ).view(1, -1)
        input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
        return input_features_mask
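    # Worked example (illustrative only): audio_embed_sizes = [3, 5] gives
    #     mask_indices = [[0, 1, 2, 3, 4]]
    #     input_features_mask = [[True, True, True, False, False],
    #                            [True, True, True, True,  True ]]
    # so every entry keeps exactly its own number of audio embeddings.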

    def _pad_and_stack_input_features(
        self,
        input_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """Given a list of input features of varying length, pad them to the
        same length and stack them into a torch.Tensor.

        NOTE: Usually, padding is done in the input processor/feature
        extractor, where the audio is zero padded prior to the computation of
        the Mel features; the resulting values are only constant within a
        batch and are generally nonzero (i.e., slightly negative numbers).
        We should validate that this is okay since we don't use a feature
        attention mask, but the more important thing is that we apply the
        input_features_mask with variable length batches.

        Args:
            input_features: list[torch.Tensor]
                Input features to be coerced into a tensor.
        Returns:
            torch.Tensor: Tensor of shape [bsz, num_features, 160], where
            num_features is the max number of features of any entry in the
            batch.
        """
        # Each list entry has shape [1, num_features, 160]
        feat_lens = [feats.shape[1] for feats in input_features]
        padding = [max(feat_lens) - length for length in feat_lens]
        # TODO (Alex) - Validate that it's okay to zero pad like this;
        # in transformers we zero pad prior to calculating the speech features,
        # so the value is not zero and is dependent on the batched features.
        padded = [
            torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
            for feats, pad in zip(input_features, padding)
        ]
        stacked_features = torch.cat(padded, dim=0).to(input_features[0])
        return stacked_features
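    # Illustrative example (shapes only): inputs of shape [1, 500, 160] and
    # [1, 320, 160] are zero padded along the feature axis by 0 and 180
    # respectively, then concatenated along dim 0 into a [2, 500, 160] tensor.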

    def _process_audio_input(
        self,
        audio_input: GraniteSpeechAudioInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Compute the audio features to be merged into the LLM embeddings.

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.
        Returns:
            tuple[torch.Tensor, ...]: Tuple of length bsz containing the
            projected audio embeddings for each entry in the batch.
        """
        # TODO (Alex) - support embedding inputs
        encoder_embeds = self.encoder(audio_input["input_features"])
        # [bsz, <max feature size>, 4096]
        projected_embeds = self.projector(encoder_embeds)
        # Apply mask on variable length audio features
        masked_embeds = projected_embeds[audio_input["input_features_mask"]]
        # Split variable length features into a tuple
        return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
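    # Rough, symbolic trace of the audio path above:
    #   input_features [bsz, num_mel_frames, 160]
    #   -> encoder     [bsz, num_mel_frames, encoder_hidden]
    #   -> projector   [bsz, num_projected_features, text_hidden]
    #   -> boolean mask + split -> one [audio_embed_size_i, text_hidden]
    #      tensor per batch entry, ready to be scattered into the prompt at
    #      the <|audio|> placeholder positions.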

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Compute the audio embeddings if audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []

        audio_features = self._process_audio_input(audio_input)
        return audio_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        model_output = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )