# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import abc
import math
from typing import Any, Literal
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from transformers import PretrainedConfig
from vllm.model_executor.models.phi4mm_utils import (
AbsolutePositionalEncoding,
ConvModule,
FeedForward,
MeanVarianceNormLayer,
MultiHeadedAttention,
MultiSequential,
NemoConvSubsampling,
T5RelativeAttentionLogitBias,
adaptive_enc_mask,
get_offset,
unfold_tensor,
)
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
class ConformerEncoderLayer(nn.Module):
"""ConformerEncoder Layer module.
    For more details, see the Conformer paper:
    https://arxiv.org/abs/2005.08100
    This module implements the Conformer block layer.
Args:
d_model: int
attention dim.
ext_pw_out_channel: int
if > 0, ext_pw_out_channel is a dim channel size
for the last pointwise conv after swish activation.
depthwise_seperable_out_channel: int
            if set to a value different from 0,
            depthwise_seperable_out_channel will be used as the
            channel_out of the second conv1d layer.
            otherwise, if it equals 0, the second conv1d layer is skipped.
depthwise_multiplier: int
            number of times the input_dim channels are duplicated. This value
            will be used to compute the hidden channels of the Conv1D.
n_head: int
the number of heads for multihead attention module.
d_ffn: int
output size of the feed_forward blocks.
ext_pw_kernel_size: int
kernel size of the conv pointwise of the conformer.
kernel_size: int
kernel size.
dropout_rate: float
dropout rate.
causal: bool, optional
            if set to True, convolutions have no access
to future frames. default False.
batch_norm: bool, optional
if set to True, apply batchnorm before activation
in ConvModule layer of the conformer.
default False
activation: str, optional
activation function name,
one of ["relu", "swish", "sigmoid"],
sigmoid activation is only used with "glu_in_fnn=True",
default "relu".
chunk_se: int, optional
0 for offline SE.
1 for streaming SE, where mean is computed
by accumulated history until current chunk_se.
2 for streaming SE, where mean is computed
by only the current chunk.
default 0.
chunk_size: int, optional
chunk_size for cnn. default 18
conv_activation: str, optional
activation function used in ConvModule part
of the conformer, default "relu".
conv_glu_type: str, optional
activation function used for the glu inside
the ConvModule part of the conformer.
default: "sigmoid".
bias_in_glu: bool, optional
if set to True, use additive bias in the weight module
before GLU.
linear_glu_in_convm: bool, optional
if set to True, use GLULinear module,
            otherwise, use the GLUPointWiseConv module.
default to False.
attention_inner_dim: int, optional
if equal to -1, attention dim for linears k/q/v is
equal to d_model. otherwise attention_inner_dim is used.
default -1.
attention_glu_type: str, optional
activation function for glu used in the multihead attention,
default "swish".
activation_checkpointing: str, optional
a dictionary of {"module","interval","offload"}, where
"module": str
accept ["transformer", "attention"] to select
which module should do activation checkpointing.
"interval": int, default 1,
interval of applying activation checkpointing,
interval = 1 means that we apply checkpointing
on every layer (if activation), otherwise,
we apply it every x interval.
"offload": bool, default False,
if set to True, we offload activation to cpu and
reload it during backward, otherwise,
we recalculate activation in backward.
default "".
export: bool, optional
            if set to True, it removes the padding from convolutional layers
            and allows the ONNX conversion for inference.
default False.
use_pt_scaled_dot_product_attention: bool, optional
if set to True, use pytorch's scaled dot product attention
implementation in training.
attn_group_sizes: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attn_group_sizes < attention_heads = Grouped-Query Attention
attn_group_sizes = attention_heads = Multi-Query Attention
"""
def __init__(
self,
d_model: int = 512,
ext_pw_out_channel: int = 0,
depthwise_seperable_out_channel: int = 256,
depthwise_multiplier: int = 1,
n_head: int = 4,
d_ffn: int = 2048,
ext_pw_kernel_size: int = 1,
kernel_size: int = 3,
dropout_rate: float = 0.1,
causal: bool = False,
batch_norm: bool = False,
activation: str = "relu",
chunk_se: int = 0,
chunk_size: int = 18,
conv_activation: str = "relu",
conv_glu_type: str = "sigmoid",
bias_in_glu: bool = True,
linear_glu_in_convm: bool = False,
attention_inner_dim: int = -1,
attention_glu_type: str = "swish",
activation_checkpointing: str = "",
export: bool = False,
use_pt_scaled_dot_product_attention: bool = False,
attn_group_sizes: int = 1,
) -> None:
super().__init__()
self.feed_forward_in = FeedForward(
d_model=d_model,
d_inner=d_ffn,
dropout_rate=dropout_rate,
activation=activation,
bias_in_glu=bias_in_glu,
)
self.self_attn = MultiHeadedAttention(
n_head,
d_model,
dropout_rate,
attention_inner_dim,
attention_glu_type,
bias_in_glu,
use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
group_size=attn_group_sizes,
)
self.conv = ConvModule(
d_model,
ext_pw_out_channel,
depthwise_seperable_out_channel,
ext_pw_kernel_size,
kernel_size,
depthwise_multiplier,
dropout_rate,
causal,
batch_norm,
chunk_se,
chunk_size,
conv_activation,
conv_glu_type,
bias_in_glu,
linear_glu_in_convm,
export=export,
)
self.feed_forward_out = FeedForward(
d_model=d_model,
d_inner=d_ffn,
dropout_rate=dropout_rate,
activation=activation,
bias_in_glu=bias_in_glu,
)
self.layer_norm_att = nn.LayerNorm(d_model)
self.layer_norm = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
pos_k: torch.Tensor,
pos_v: torch.Tensor,
mask: torch.Tensor,
relative_attention_bias: Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""ConformerEncoder forward.
Args:
x: input feature of shape (batch, max_time_in, size)
pos_k: positional key embedding.
pos_v: positional value embedding.
mask: mask for x (batch, max_time_in)
relative_attention_bias: bias added to attention logits w.r.t.
relative positions (1, n_head, time1, time2)
"""
x = x + 0.5 * self.feed_forward_in(x)
norm_x = self.layer_norm_att(x)
x = x + self.self_attn(
norm_x,
norm_x,
norm_x,
pos_k,
pos_v,
mask,
relative_attention_bias=relative_attention_bias,
)
x = x + self.conv(x)
x = x + 0.5 * self.feed_forward_out(x)
out = self.layer_norm(x)
return out, pos_k, pos_v, mask
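# Illustrative usage sketch for ConformerEncoderLayer (commented out; the
# shapes and hyperparameters below are hypothetical, not vLLM defaults).
# The layer applies the macaron structure: half-step FFN, self-attention,
# convolution module, half-step FFN, then a final LayerNorm.
#
#   layer = ConformerEncoderLayer(d_model=512, n_head=8, d_ffn=2048)
#   x = torch.randn(4, 100, 512)                      # (batch, time, d_model)
#   mask = torch.ones(4, 100, 100, dtype=torch.bool)  # (batch, time, time)
#   out, _, _, _ = layer(x, pos_k=None, pos_v=None, mask=mask)
#   # out: (4, 100, 512)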
class TransformerEncoderBase(abc.ABC, nn.Module):
"""The Base class for Transformer based encoders
Please set causal = True in streaming model
Args:
input_size: int
input feature dimension.
chunk_size: int, list(int)
Number of frames for each chunk
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training
Some examples for the 2 cases:
chunk_size = 12
chunk_size = [6, 8, 12, 24]
left_chunk: int, list(int)
Number of chunks used for masking in streaming mode.
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training. When
chunk_size is a list, left_chunk must be a list with same length.
Some examples for the 2 cases:
left_chunk = 6
left_chunk = [12, 9, 6, 3]
attention_dim: int, optional
attention dimension. default 256.
attention_heads: int, optional
the number of heads. default 4
        input_layer: str, optional
            input layer type before Conformer;
            currently only "nemo_conv" is supported.
            default "nemo_conv"
cnn_out: int, optional
the number of CNN channels before Conformer.
default -1.
cnn_layer_norm: bool, optional
layer norm between Conformer and the first CNN.
default False.
time_reduction: int, optional
time reduction factor
default 4
dropout_rate: float, optional
dropout rate. default 0.1
padding_idx: int, optional
padding index for input_layer=embed
default -1
relative_attention_bias_args: dict, optional
use more efficient scalar bias-based relative multihead attention
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
positional_dropout_rate: float, optional
dropout rate after positional encoding. default 0.0
nemo_conv_settings: dict, optional
A dictionary of settings for NeMo Subsampling.
default None
conv2d_extra_padding: str, optional
Add extra padding in conv2d subsampling layers. Choices are
(feat, feat_time, none, True).
if True or feat_time, the extra padding is added into non full
supraframe utts in batch.
Default: none
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
Attention
attention_group_size = attention_heads = Multi-Query Attention
"""
def __init__(
self,
input_size: int,
chunk_size: int | list[int],
left_chunk: int | list[int],
attention_dim: int = 256,
attention_heads: int = 4,
input_layer: str = "nemo_conv",
cnn_out: int = -1,
cnn_layer_norm: bool = False,
time_reduction: int = 4,
dropout_rate: float = 0.0,
padding_idx: int = -1,
relative_attention_bias_args: dict[str, Any] | None = None,
positional_dropout_rate: float = 0.0,
nemo_conv_settings: dict[str, Any] | None = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
attention_group_size: int = 1,
encoder_embedding_config: dict[str, Any] | None = None,
) -> None:
super().__init__()
self.input_size = input_size
self.input_layer = input_layer
self.chunk_size = chunk_size
self.left_chunk = left_chunk
self.attention_dim = attention_dim
self.num_heads = attention_heads
self.attention_group_size = attention_group_size
self.time_reduction = time_reduction
self.nemo_conv_settings = nemo_conv_settings
self.encoder_embedding_config = encoder_embedding_config
if self.input_layer == "nemo_conv":
default_nemo_conv_settings = {
"subsampling": "dw_striding",
"subsampling_factor": self.time_reduction,
"feat_in": input_size,
"feat_out": attention_dim,
"conv_channels": 256,
"subsampling_conv_chunking_factor": 1,
"activation": nn.ReLU(),
"is_causal": False,
}
# Override any of the defaults with the incoming, user settings
if nemo_conv_settings:
default_nemo_conv_settings.update(nemo_conv_settings)
for i in ["subsampling_factor", "feat_in", "feat_out"]:
assert i not in nemo_conv_settings, (
"{i} should be specified outside of the NeMo dictionary"
)
self.embed = NemoConvSubsampling(
**default_nemo_conv_settings,
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.pos_emb = AbsolutePositionalEncoding(
attention_dim, positional_dropout_rate
)
self.relative_attention_bias_type = (
relative_attention_bias_args.get("type")
if relative_attention_bias_args
else None
)
if self.relative_attention_bias_type == "t5":
assert self.num_heads % self.attention_group_size == 0, (
"attention_group_size must divide n_head"
)
self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
self.num_heads // self.attention_group_size,
max_distance=relative_attention_bias_args.get(
"t5_bias_max_distance", 1000
),
symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
)
else:
raise NotImplementedError
self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"]
)
def compute_lens_change(
self, feature_lens: int | torch.Tensor
) -> int | torch.Tensor:
"""feature_lens: int
return updated feature lens.
This used to return a different lambda function for each case that
computed the right thing. That does not work within Torchscript.
If you really need this to be faster, create nn.Module()-s for all
the cases and return one of them. Torchscript does support that.
"""
if self.input_layer == "nemo_conv":
# Handle the special causal case
subsampling_causal_cond = self.nemo_conv_settings.get(
"subsampling", "dw_striding"
) in [
"dw_striding",
"striding",
"striding_conv1d",
]
is_causal = self.nemo_conv_settings.get("is_causal", False)
if is_causal and subsampling_causal_cond:
lens_change = (
torch.ceil(feature_lens / self.time_reduction).long()
if isinstance(feature_lens, Tensor)
else math.ceil(feature_lens / self.time_reduction)
)
feature_lens_remainder = feature_lens % self.time_reduction
if isinstance(feature_lens, Tensor):
lens_change[feature_lens_remainder != 1] += 1
elif feature_lens_remainder != 1:
lens_change += 1
return lens_change
ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
return ceil_func(feature_lens / self.time_reduction)
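    # Worked example (illustrative): with time_reduction=4 and a non-causal
    # "dw_striding" subsampler, an input of 103 frames maps to
    # ceil(103 / 4) = 26 output frames. On the causal path above, the
    # remainder correction adds one extra frame whenever
    # feature_lens % time_reduction != 1 (so 103 frames would give 27).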
@abc.abstractmethod
def forward(self) -> Any:
"""Abstract forward method implementation."""
def _chunk_size_selection(
self,
chunk_size: int | list[int] | None = None,
left_chunk: int | list[int] | None = None,
) -> tuple[int, int]:
"""If chunk size is a list, we will randomly select a chunk size."""
if chunk_size is None:
chunk_size = self.chunk_size
if left_chunk is None:
left_chunk = self.left_chunk
if isinstance(chunk_size, list):
# Variable chunk size during training
chunk_size_index = int(
torch.randint(low=0, high=len(chunk_size), size=(1,))
)
chunk_size_train_eff = chunk_size[chunk_size_index]
if not isinstance(left_chunk, list):
raise ValueError(
"Since chunk_size is a list, left_chunk must be a list"
)
if len(left_chunk) != len(chunk_size):
raise ValueError(
"The length of left_chunk must be the same as length of chunk_size."
)
left_chunk_train_eff = left_chunk[chunk_size_index]
else:
chunk_size_train_eff = chunk_size
left_chunk_train_eff = left_chunk
return chunk_size_train_eff, left_chunk_train_eff
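    # Illustrative example of the selection above: with chunk_size=[6, 8, 12, 24]
    # and left_chunk=[12, 9, 6, 3], one index (say 2) is drawn uniformly at
    # random, yielding chunk_size_train_eff=12 and left_chunk_train_eff=6 for
    # this forward pass; scalar inputs are returned unchanged.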
def _get_embed_class(self, embed: nn.Module) -> nn.Module:
# pylint: disable=protected-access
is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
embed_class = embed
if is_embed_using_act_chkpt:
embed_class = embed._checkpoint_wrapped_module
if is_embed_fsdp_wrapped:
embed_class = embed.module
return embed_class
def _forward_embeddings_core(
self, input_tensor: torch.Tensor, masks: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
embed_class = self._get_embed_class(self.embed)
assert isinstance(embed_class, NemoConvSubsampling)
input_tensor, masks = self.embed(input_tensor, masks)
return input_tensor, masks
def _position_embedding(
self, input_tensor: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
pos_k = None
pos_v = None
if self.relative_attention_bias_layer is None:
input_tensor = self.pos_emb(
input_tensor
) # default to add abs sinusoid embedding
return pos_k, pos_v
def _streaming_mask(
self,
seq_len: int,
batch_size: int,
chunk_size: int | list[int],
left_chunk: int | list[int],
) -> torch.Tensor:
chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
chunk_size, left_chunk
)
        # Create mask matrix for streaming.
        # chunk_start_idx stores chunk start indices; if chunk_size is 18,
        # it is [0, 18, 36, ...].
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
enc_streaming_mask = (
adaptive_enc_mask(
seq_len, chunk_start_idx, left_window=left_chunk_train_eff
)
.unsqueeze(0)
.expand([batch_size, -1, -1])
)
return enc_streaming_mask
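    # Illustrative example (assuming adaptive_enc_mask keeps each frame's own
    # chunk plus `left_window` chunks to its left): with seq_len=6,
    # chunk_size=2, left_chunk=1, chunk_start_idx is [0, 2, 4], so frames 2-3
    # may attend to frames 0-3 but not 4-5. The (seq_len, seq_len) mask is
    # then expanded across the batch dimension.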
def forward_embeddings(
self,
xs_pad: torch.Tensor,
masks: torch.Tensor,
chunk_size_nc: int | list[int] | None = None,
left_chunk_nc: int | list[int] | None = None,
) -> (
tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor,
torch.Tensor,
]
| tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]
):
"""Forwarding the inputs through the top embedding layers
Args:
xs_pad: torch.Tensor
input tensor
masks: torch.Tensor
input mask
chunk_size_nc: (optional, default is None) chunk size for
non-causal layers
left_chunk_nc: (optional, default is None) # of left chunks for
non-causal layers
"""
# pylint: disable=R0915
# get new lens.
seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
if seq_len <= 0:
            raise ValueError(
                f"The sequence length after time reduction is invalid: "
                f"{seq_len}. Your input feature is too short. Consider "
                f"filtering out very short utterances in the data loader."
            )
batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(
seq_len, batch_size, self.chunk_size, self.left_chunk
)
if xs_pad.is_cuda:
enc_streaming_mask = enc_streaming_mask.cuda()
xs_pad = xs_pad.cuda()
input_tensor = xs_pad
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
streaming_mask = enc_streaming_mask
if streaming_mask is not None and masks is not None:
hs_mask = masks & streaming_mask
elif masks is not None:
hs_mask = masks
else:
hs_mask = streaming_mask
if chunk_size_nc is not None:
enc_streaming_mask_nc = self._streaming_mask(
seq_len, batch_size, chunk_size_nc, left_chunk_nc
)
if xs_pad.is_cuda:
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
if masks is not None:
hs_mask_nc = masks & enc_streaming_mask_nc
else:
hs_mask_nc = enc_streaming_mask_nc
else:
hs_mask_nc = None
pos_k, pos_v = self._position_embedding(input_tensor)
if chunk_size_nc is None:
return input_tensor, pos_k, pos_v, hs_mask, masks
return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
def get_offset(self) -> int:
"""Returns offset used when retaining inputs for decoding.
        This is essentially how many additional frames have to be added to
the front-end CNN input to ensure it can produce a single output.
So if the "padding" parameter is 0, typically offset will be > 0.
"""
return get_offset(self.input_layer, self.time_reduction)
class ConformerEncoder(TransformerEncoderBase):
"""ConformerEncoder module.
    See the original paper for more details:
    https://arxiv.org/abs/2005.08100
    Please set causal = True in streaming models.
Args:
input_size: int
input feature dimension.
chunk_size: int, list(int)
Number of frames for each chunk
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training
Some examples for the 2 cases:
chunk_size = 12
chunk_size = [6, 8, 12, 24]
left_chunk: int, list(int)
Number of chunks used for masking in streaming mode.
This variable can take 2 forms:
int: Used for inference, or single chunk size training
list(int) : Used only for variable chunk size training. When
chunk_size is a list, left_chunk must be a list with same length.
Some examples for the 2 cases:
left_chunk = 6
left_chunk = [12, 9, 6, 3]
num_lang: int
This parameter is used to store the number of languages in the
lang_dict, only used for multiseed/multilingual models.
default None.
attention_dim: int, optional
attention dimension. default 256.
attention_heads: int, optional
the number of heads. default 4
linear_units:
the number of units of position-wise feed forward.
default 2048
        num_blocks:
            number of Transformer layers. default 6
dropout_rate: float, optional
dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer;
            currently only "nemo_conv" is supported.
            default "nemo_conv"
causal: bool, optional
            if set to True, convolutions have no access
to future frames. default False.
batch_norm: bool, optional
if set to True, apply batchnorm before activation
in ConvModule layer of the conformer.
default False
cnn_out: int, optional
the number of CNN channels before Conformer.
default -1.
cnn_layer_norm: bool, optional
layer norm between Conformer and the first CNN.
default False.
ext_pw_out_channel: int, optional
            the number of channels for the CNN
            before depthwise_seperable_CNN.
If 0 then use linear. default 0.
ext_pw_kernel_size: int, optional
            kernel size of the pointwise conv before depthwise_seperable_CNN.
            only works for ext_pw_out_channel > 0.
default 1
depthwise_seperable_out_channel: int, optional
            the number of channels for
depthwise_seperable_CNN.
default 256.
depthwise_multiplier: int, optional
            the channel multiplier for
depthwise_seperable_CNN.
default 1.
chunk_se: int, optional
0 for offline SE.
1 for streaming SE, where mean is computed
by accumulated history until current chunk_se.
2 for streaming SE, where mean is computed
by only the current chunk.
default 0.
kernel_size: int, optional
            the kernel size for depthwise_seperable_CNN.
default 3.
activation: str, optional
FeedForward block activation.
one of ["relu", "swish", "sigmoid"]
default "relu".
conv_activation: str, optional
activation function used in ConvModule part
of the conformer, default "relu".
conv_glu_type: str, optional
            activation function used for the GLU in depthwise_seperable_CNN,
default "sigmoid"
bias_in_glu: bool, optional
if set to True, use additive bias in the weight module
before GLU. default True
linear_glu_in_convm: bool, optional
if set to True, use GLULinear module,
            otherwise, use the GLUPointWiseConv module.
default to False.
attention_glu_type: str
            only works for glu_in_attention != 0.
default "swish".
export: bool, optional
            if set to True, it removes the padding from convolutional layers
            and allows the ONNX conversion for inference.
default False.
activation_checkpointing: str, optional
            a dictionary of {"module","interval","offload"}, where
"module": str
accept ["transformer", "attention"] to select
which module should do activation checkpointing.
"interval": int, default 1,
interval of applying activation checkpointing,
interval = 1 means that we apply checkpointing
on every layer (if activation), otherwise,
we apply it every x interval.
"offload": bool, default False,
if set to True, we offload activation to cpu and
reload it during backward, otherwise,
we recalculate activation in backward.
default "".
extra_layer_output_idx: int
the layer index to be exposed.
relative_attention_bias_args: dict, optional
use more efficient scalar bias-based relative multihead attention
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
        time_reduction: int, optional
time reduction factor
default 4
use_pt_scaled_dot_product_attention: whether to use pytorch scaled
dot product attention in training.
Default: False
nemo_conv_settings: dict, optional
A dictionary of settings for NeMo Subsampling.
default: None
usage: nemo_conv_settings=
{
"subsampling":
dw_striding/striding/dw_striding_conv1d/striding_conv1d,
"conv_channels": int,
"subsampling_conv_chunking_factor": int,
"is_causal": True/False
}
conv2d_extra_padding: str, optional
Add extra padding in conv2d subsampling layers. Choices are
(feat, feat_time, none, True)
Default: none
replication_pad_for_subsample_embedding: For batched-streaming
decoding, use "replication" padding for the cache at start of
utterance.
Default: False
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
Attention
attention_group_size = attention_heads = Multi-Query Attention
"""
extra_multi_layer_output_idxs: list[int]
def __init__( # pylint: disable-all
self,
input_size: int,
chunk_size: int | list[int],
left_chunk: int | list[int],
num_lang: int | None = None,
attention_dim: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
input_layer: str = "nemo_conv",
causal: bool = True,
batch_norm: bool = False,
cnn_out: int = -1,
cnn_layer_norm: bool = False,
ext_pw_out_channel: int = 0,
ext_pw_kernel_size: int = 1,
depthwise_seperable_out_channel: int = 256,
depthwise_multiplier: int = 1,
chunk_se: int = 0,
kernel_size: int = 3,
activation: str = "relu",
conv_activation: str = "relu",
conv_glu_type: str = "sigmoid",
bias_in_glu: bool = True,
linear_glu_in_convm: bool = False,
attention_glu_type: str = "swish",
export: bool = False,
extra_layer_output_idx: int = -1,
extra_multi_layer_output_idxs: list[int] = [], # noqa
activation_checkpointing: str = "",
relative_attention_bias_args: dict[str, Any] | None = None,
time_reduction: int = 4,
use_pt_scaled_dot_product_attention: bool = False,
nemo_conv_settings: dict[str, Any] | None = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
replication_pad_for_subsample_embedding: bool = False,
attention_group_size: int = 1,
encoder_embedding_config: dict[str, Any] | None = None,
) -> None:
super().__init__(
input_size,
chunk_size,
left_chunk,
attention_dim,
attention_heads,
input_layer,
cnn_out,
cnn_layer_norm,
time_reduction,
dropout_rate=dropout_rate,
relative_attention_bias_args=relative_attention_bias_args,
positional_dropout_rate=0.0,
nemo_conv_settings=nemo_conv_settings,
conv2d_extra_padding=conv2d_extra_padding,
attention_group_size=attention_group_size,
encoder_embedding_config=encoder_embedding_config,
)
self.num_blocks = num_blocks
self.num_lang = num_lang
self.kernel_size = kernel_size
self.replication_pad_for_subsample_embedding: bool = (
replication_pad_for_subsample_embedding
)
assert self.num_heads % attention_group_size == 0, (
"attention_group_size must divide n_head"
)
self.num_heads_k = self.num_heads // attention_group_size
self.encoders = MultiSequential(
*[
ConformerEncoderLayer(
d_model=attention_dim,
ext_pw_out_channel=ext_pw_out_channel,
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
depthwise_multiplier=depthwise_multiplier,
n_head=attention_heads,
d_ffn=linear_units,
ext_pw_kernel_size=ext_pw_kernel_size,
kernel_size=kernel_size,
dropout_rate=dropout_rate,
causal=causal,
batch_norm=batch_norm,
activation=activation,
chunk_se=chunk_se,
chunk_size=chunk_size,
conv_activation=conv_activation,
conv_glu_type=conv_glu_type,
bias_in_glu=bias_in_glu,
linear_glu_in_convm=linear_glu_in_convm,
attention_glu_type=attention_glu_type,
activation_checkpointing=activation_checkpointing,
export=export,
use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
attn_group_sizes=attention_group_size,
)
for _ in range(num_blocks)
]
)
self.extra_layer_output_idx = extra_layer_output_idx
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
# Make a zeros scalar we can use in get_initial_state to determine
# the device and the needed dtype:
self.register_buffer("dev_type", torch.zeros(()), persistent=False)
def init_relative_attention_bias(
self, input_tensor: torch.Tensor
) -> torch.Tensor | None:
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)
        return None
def calculate_hs_mask(
self, xs_pad: torch.Tensor, device: torch.device, mask: torch.Tensor | None
) -> torch.Tensor:
max_audio_length = xs_pad.shape[1]
batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(
max_audio_length, batch_size, self.chunk_size, self.left_chunk
)
enc_streaming_mask = enc_streaming_mask.to(device)
if mask is None:
return enc_streaming_mask
feature_lens = mask.sum(1)
padding_length = feature_lens
pad_mask = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) < padding_length.unsqueeze(1)
pad_mask = pad_mask.unsqueeze(1)
pad_mask = pad_mask & enc_streaming_mask
return pad_mask
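    # Note on the mask above: it combines two constraints. The padding mask
    # marks positions within each utterance's true length (mask.sum(1)), and
    # the streaming mask limits attention to the configured chunk/left-chunk
    # window; a position must satisfy both, hence the elementwise AND, giving
    # a (batch, time, time) boolean attention mask.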
@torch.jit.ignore
def forward(
self, xs_pad: torch.Tensor, masks: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Conformer Forward function
Args:
xs_pad: torch.Tensor
input tensor
masks: torch.Tensor
post-embedding input lengths
"""
xs_pad = self.encoder_embedding(xs_pad)
input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
xs_pad, masks
)
unfolded = False
ori_bz, seq_len, D = input_tensor.shape
max_seq_len = 500 # maximum position for absolute positional encoding
if seq_len > max_seq_len:
# audio sequence is longer than max_seq_len, unfold it into chunks
# of max_seq_len
unfolded = True
# the unfold op will drop residual frames, pad it to the multiple
# of max_seq_len
if seq_len % max_seq_len > 0:
chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
else:
chunk_pad_size = 0
if chunk_pad_size > 0:
input_tensor_pad = F.pad(
input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
)
input_tensor = input_tensor_pad.to(input_tensor.device)
input_tensor = unfold_tensor(input_tensor, max_seq_len)
if masks is not None:
                # revise hs_mask here because the previously calculated hs_mask
                # did not consider the extra padding
                subsampled_pad_mask = masks.squeeze(
                    1
                )  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsampled_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )  # extra padding to the pad mask
                extra_padded_subsampled_pad_mask = (
                    extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
                )
                masks_unfold = unfold_tensor(
                    extra_padded_subsampled_pad_mask, max_seq_len
                )  # unfold the pad mask like we did to the input tensor
masks_unfold = masks_unfold.squeeze(
-1
).bool() # unfold op does not support bool tensor
else:
masks_unfold = None
hs_mask = self.calculate_hs_mask(
input_tensor, input_tensor.device, masks_unfold
) # calculate hs_mask based on the unfolded pad mask
# layer_emb = None
relative_attention_bias = self.init_relative_attention_bias(input_tensor)
_simplified_path = (
self.extra_layer_output_idx == -1 and relative_attention_bias is None
)
if _simplified_path:
input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
else:
for i, layer in enumerate(self.encoders):
input_tensor, _, _, _ = layer(
input_tensor,
pos_k,
pos_v,
hs_mask,
relative_attention_bias=relative_attention_bias,
)
# if i == self.extra_layer_output_idx:
# layer_emb = input_tensor
if unfolded:
embed_dim = input_tensor.shape[-1]
input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
# if we ever padded before unfolding, we need to remove the padding
if chunk_pad_size > 0:
input_tensor = input_tensor[:, :-chunk_pad_size, :]
return input_tensor, masks # , layer_emb
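# Illustrative shape walk-through for the long-audio path in forward() above
# (values are hypothetical): with max_seq_len=500 and a subsampled length of
# 1200 frames, chunk_pad_size = 500 - (1200 % 500) = 300, so the padded
# tensor of 1500 frames is unfolded into (batch * 3, 500, D) windows, encoded
# window by window, reshaped back to (batch, 1500, D), and the trailing 300
# padded frames are dropped before returning.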
class WindowQformer(nn.Module):
"""Window-level Qformer"""
def __init__(
self,
window_size: int = 8,
num_queries: int = 1,
num_blocks: int = 2,
attention_dim: int = 512,
attention_heads: int = 8,
linear_units: int = 2048,
dropout_rate: float = 0.0,
normalize_before: bool = True,
):
super().__init__()
self.decoders = nn.ModuleList(
[
nn.TransformerDecoderLayer(
d_model=attention_dim,
nhead=attention_heads,
dim_feedforward=linear_units,
dropout=dropout_rate,
activation="relu",
batch_first=True,
norm_first=normalize_before, # TODO need to verify
)
for _ in range(num_blocks)
]
)
self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
self.after_norm = (
nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
)
self.window_size = window_size
def forward(
self,
audio_embed: torch.Tensor,
mask: torch.Tensor | None,
embed_len: int | None = None,
) -> tuple[torch.Tensor, int | None]:
"""forward decoder"""
# audio_embed: N x T x D => N x D x T
audio_embed = audio_embed.transpose(1, 2)
# audio_embed: N x D x 1 x T => N x DK x T'
padding = audio_embed.shape[-1] % self.window_size
if padding > 0:
audio_embed = F.pad(
audio_embed, (0, self.window_size - padding), "constant", 0
)
embed_chunk = F.unfold(
audio_embed[..., None, :],
kernel_size=(1, self.window_size),
stride=(1, self.window_size),
)
bsz, _, slen = embed_chunk.shape
# N x D x K x T'
embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen)
# N x T' x K x D
embed_chunk = embed_chunk.transpose(1, 3).contiguous()
# NT' x K x D
embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1)
# NT' x 1 x D
q = self.queries.expand(bsz * slen, -1, -1)
for layer in self.decoders:
q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)
if self.after_norm is not None:
q = self.after_norm(q)
if embed_len is not None:
embed_len = embed_len // self.window_size
# N x T' x D
out = q.view(bsz, slen, -1)
return out, embed_len
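# Illustrative shape walk-through for WindowQformer.forward (hypothetical
# values): with window_size=8, num_queries=1 and input of shape (2, 100, 512),
# the time axis is padded to 104 and unfolded into T'=13 windows of 8 frames;
# each window is summarized by cross-attending the learned query against its
# 8 frames, giving an output of shape (2, 13, 512). embed_len, if given, is
# divided by the window size.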
class AudioEmbedding(nn.Module):
"""Image embedding."""
def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
super().__init__()
self.config = config
# n_embed or hidden_size for text LM
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
audio_dim_out = (
None # Set this variable according to the actual audio processor
)
self.layer_idx = -2
if (
isinstance(config.audio_processor, dict)
and config.audio_processor.get("name", None) == "cascades"
):
encoder_config = config.audio_processor.get("config", None)
assert encoder_config is not None
self.encoder = ConformerEncoder(**encoder_config)
audio_dim_out = encoder_config["attention_dim"]
n_mels = encoder_config["input_size"]
else:
raise NotImplementedError("")
assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
self.audio_dim_out = audio_dim_out
self.audio_dim_in = n_mels
self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)
self.downsample_rate = kwargs.get("downsample_rate", 1)
if kwargs.get("use_qformer", False):
qformer_config = kwargs.get("qformer_config", {})
qformer_config["attention_dim"] = audio_dim_out
self.qformer = WindowQformer(**qformer_config)
else:
self.qformer = None
if kwargs.get("use_conv_downsample", False):
assert self.qformer is None, (
"don't support use qformer and conv downsample together"
)
nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
default_nemo_conv_settings = {
"subsampling": "dw_striding",
"subsampling_factor": self.downsample_rate,
"feat_in": audio_dim_out,
"feat_out": audio_dim_out,
"conv_channels": 256,
"subsampling_conv_chunking_factor": 1,
"activation": nn.ReLU(),
"is_causal": False,
}
# Override any of the defaults with the incoming, user settings
if nemo_conv_settings:
default_nemo_conv_settings.update(nemo_conv_settings)
for i in ["subsampling_factor", "feat_in", "feat_out"]:
assert i not in nemo_conv_settings, (
"{i} should be specified outside of the NeMo dictionary"
)
self.conv_ds = NemoConvSubsampling(
**default_nemo_conv_settings,
)
else:
self.conv_ds = None
projection_cls = kwargs.get("projection_cls", "linear")
if projection_cls == "linear":
self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
elif projection_cls == "mlp":
# follow llava-v1.5's implementation
# (do not use image_projection and image_proj_norm)
dim_projection = hidden_size
depth = 2
self.linear_downsample_rate = (
1 if (self.qformer or self.conv_ds) else self.downsample_rate
)
layers = [
nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
]
for _ in range(1, depth):
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.audio_projection = nn.Sequential(*layers)
# NOTE vision-speech tasks use a separate projection layer
layers = [
nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
]
for _ in range(1, depth):
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.audio_projection_for_vision = nn.Sequential(*layers)
else:
raise NotImplementedError(
f"projection_cls = {projection_cls}, not implemented"
)
# TODO: audio sequence compression - Qformer
self.vocab_size = config.vocab_size
self.input_embeds = None
self.audio_embed_sizes = None
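    # Illustrative note on the "mlp" projection path (hypothetical sizes):
    # with audio_dim_out=1024, hidden_size=3072, downsample_rate=8 and no
    # qformer/conv downsampler, get_audio_features stacks 8 consecutive frames
    # per step, so the first Linear maps 1024 * 8 = 8192 -> 3072, followed by
    # GELU and Linear(3072, 3072); a parallel projection of the same shape is
    # kept for the vision-speech ("vision") mode.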
def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
self.input_embeds = input_embeds
def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
self.audio_embed_sizes = audio_embed_sizes
def get_audio_features(
self,
input_embeds: torch.Tensor,
audio_attention_mask: torch.Tensor | None = None,
audio_projection_mode: str = "speech",
) -> torch.Tensor:
"""
arguments:
input_embeds: audio features (B, T, D) B: num audios in a sequence
"""
if self.freeze_audio_processor:
with torch.no_grad():
audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
else:
audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
if self.qformer is not None:
audio_features, _ = self.qformer(audio_features, mask=None)
if self.conv_ds is not None:
if masks is not None:
masks = masks.squeeze(1)
audio_features, masks = self.conv_ds(audio_features, mask=masks)
if self.linear_downsample_rate != 1:
bs, seq_len, feat_dim = audio_features.size()
padding = seq_len % self.linear_downsample_rate
if padding > 0:
audio_features = F.pad(
audio_features,
(0, 0, 0, self.linear_downsample_rate - padding),
"constant",
0,
)
seq_len = audio_features.size(1)
audio_features = audio_features.view(
bs,
seq_len // self.linear_downsample_rate,
feat_dim * self.linear_downsample_rate,
)
if audio_projection_mode == "speech":
audio_set_tensor = self.audio_projection(audio_features)
elif audio_projection_mode == "vision":
audio_set_tensor = self.audio_projection_for_vision(audio_features)
else:
raise ValueError(
f"audio_projection_mode = {audio_projection_mode} not implemented"
)
return audio_set_tensor
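    # Illustrative arithmetic for the reshape above (hypothetical values):
    # with linear_downsample_rate=4 and audio_features of shape (1, 250, 1024),
    # 250 % 4 = 2, so 2 padded frames bring the length to 252, which is then
    # viewed as (1, 63, 4096) before entering the projection layers.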
def forward(
self,
audio_features: torch.Tensor,
audio_attention_mask: torch.Tensor | None = None,
audio_projection_mode: str = "speech",
) -> torch.Tensor:
"""
arguments:
audio_features: audio features (T, D)
returns:
audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
"""
audio_embeds = self.get_audio_features(
audio_features.unsqueeze(0),
audio_attention_mask=audio_attention_mask,
audio_projection_mode=audio_projection_mode,
)
return audio_embeds.squeeze(0)
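# Illustrative end-to-end sketch (commented out; the config values below are
# assumptions chosen to satisfy this file's constructors, not a documented
# vLLM or Phi-4-multimodal configuration):
#
#   cfg = PretrainedConfig(
#       hidden_size=3072,
#       vocab_size=200_064,  # hypothetical
#       audio_processor={
#           "name": "cascades",
#           "config": {
#               "input_size": 80,  # log-mel bins
#               "attention_dim": 1024,
#               "chunk_size": 18,
#               "left_chunk": 6,
#               "relative_attention_bias_args": {"type": "t5"},
#               "encoder_embedding_config": {"input_size": 80},
#           },
#       },
#   )
#   embed = AudioEmbedding(cfg, downsample_rate=8, projection_cls="mlp")
#   feats = torch.randn(630, 80)  # (T, n_mels) for one audio clip
#   tokens = embed(feats)         # (num_audio_tokens, hidden_size)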