# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import abc
import math
from typing import Any, Literal

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from transformers import PretrainedConfig

from vllm.model_executor.models.phi4mm_utils import (
    AbsolutePositionalEncoding,
    ConvModule,
    FeedForward,
    MeanVarianceNormLayer,
    MultiHeadedAttention,
    MultiSequential,
    NemoConvSubsampling,
    T5RelativeAttentionLogitBias,
    adaptive_enc_mask,
    get_offset,
    unfold_tensor,
)

_AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>


class ConformerEncoderLayer(nn.Module):
    """ConformerEncoder Layer module.
    For more details, see the Conformer paper:
    https://arxiv.org/abs/2005.08100
    This module implements the Conformer block layer.

    Args:
        d_model: int
            attention dim.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is the channel size
            for the last pointwise conv after swish activation.
        depthwise_seperable_out_channel: int
            if set to a non-zero value, depthwise_seperable_out_channel
            is used as the channel_out of the second conv1d layer.
            otherwise (0), the second conv1d layer is skipped.
        depthwise_multiplier: int
            the multiplier applied to the input_dim channels; this value
            is used to compute the hidden channels of the Conv1D.
        n_head: int
            the number of heads for multihead attention module.
        d_ffn: int
            output size of the feed_forward blocks.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolutions have no access
            to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        activation: str, optional
            activation function name,
            one of ["relu", "swish", "sigmoid"],
            sigmoid activation is only used with "glu_in_fnn=True",
            default "relu".
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
            by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
            by only the current chunk.
            default 0.
        chunk_size: int, optional
            chunk_size for cnn. default 18
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation function used for the glu inside
            the ConvModule part of the conformer.
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
            otherwise, use GLUPointWiseConv module.
            default to False.
        attention_inner_dim: int, optional
            if equal to -1, attention dim for linears k/q/v is
            equal to d_model. otherwise attention_inner_dim is used.
            default -1.
        attention_glu_type: str, optional
            activation function for glu used in the multihead attention,
            default "swish".
        activation_checkpointing: str, optional
            a dictionary of {"module", "interval", "offload"}, where
            "module": str
                accept ["transformer", "attention"] to select
                which module should do activation checkpointing.
            "interval": int, default 1,
                interval of applying activation checkpointing,
                interval = 1 means that we apply checkpointing
                on every layer (if activated), otherwise,
                we apply it every `interval` layers.
            "offload": bool, default False,
                if set to True, we offload activation to cpu and
                reload it during backward, otherwise,
                we recalculate activation in backward.
            default "".
        export: bool, optional
            if set to True, it removes the padding from convolutional layers
            and allows the ONNX conversion for inference.
            default False.
        use_pt_scaled_dot_product_attention: bool, optional
            if set to True, use pytorch's scaled dot product attention
            implementation in training.
        attn_group_sizes: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attn_group_sizes < attention_heads = Grouped-Query Attention
            attn_group_sizes = attention_heads = Multi-Query Attention
    """

    def __init__(
        self,
        d_model: int = 512,
        ext_pw_out_channel: int = 0,
        depthwise_seperable_out_channel: int = 256,
        depthwise_multiplier: int = 1,
        n_head: int = 4,
        d_ffn: int = 2048,
        ext_pw_kernel_size: int = 1,
        kernel_size: int = 3,
        dropout_rate: float = 0.1,
        causal: bool = False,
        batch_norm: bool = False,
        activation: str = "relu",
        chunk_se: int = 0,
        chunk_size: int = 18,
        conv_activation: str = "relu",
        conv_glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        attention_inner_dim: int = -1,
        attention_glu_type: str = "swish",
        activation_checkpointing: str = "",
        export: bool = False,
        use_pt_scaled_dot_product_attention: bool = False,
        attn_group_sizes: int = 1,
    ) -> None:
        super().__init__()

        self.feed_forward_in = FeedForward(
            d_model=d_model,
            d_inner=d_ffn,
            dropout_rate=dropout_rate,
            activation=activation,
            bias_in_glu=bias_in_glu,
        )

        self.self_attn = MultiHeadedAttention(
            n_head,
            d_model,
            dropout_rate,
            attention_inner_dim,
            attention_glu_type,
            bias_in_glu,
            use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
            group_size=attn_group_sizes,
        )
        self.conv = ConvModule(
            d_model,
            ext_pw_out_channel,
            depthwise_seperable_out_channel,
            ext_pw_kernel_size,
            kernel_size,
            depthwise_multiplier,
            dropout_rate,
            causal,
            batch_norm,
            chunk_se,
            chunk_size,
            conv_activation,
            conv_glu_type,
            bias_in_glu,
            linear_glu_in_convm,
            export=export,
        )

        self.feed_forward_out = FeedForward(
            d_model=d_model,
            d_inner=d_ffn,
            dropout_rate=dropout_rate,
            activation=activation,
            bias_in_glu=bias_in_glu,
        )

        self.layer_norm_att = nn.LayerNorm(d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        pos_k: torch.Tensor,
        pos_v: torch.Tensor,
        mask: torch.Tensor,
        relative_attention_bias: Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """ConformerEncoder forward.

        Args:
            x: input feature of shape (batch, max_time_in, size)
            pos_k: positional key embedding.
            pos_v: positional value embedding.
            mask: mask for x (batch, max_time_in)
            relative_attention_bias: bias added to attention logits w.r.t.
                relative positions (1, n_head, time1, time2)
        """
        x = x + 0.5 * self.feed_forward_in(x)
        norm_x = self.layer_norm_att(x)

        x = x + self.self_attn(
            norm_x,
            norm_x,
            norm_x,
            pos_k,
            pos_v,
            mask,
            relative_attention_bias=relative_attention_bias,
        )
        x = x + self.conv(x)
        x = x + 0.5 * self.feed_forward_out(x)

        out = self.layer_norm(x)

        return out, pos_k, pos_v, mask
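

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the function below
# mirrors the residual "macaron" pattern computed by
# ConformerEncoderLayer.forward -- half-step feed-forward, self-attention,
# convolution, half-step feed-forward, then a final LayerNorm -- using plain
# torch modules instead of the project's FeedForward / MultiHeadedAttention /
# ConvModule blocks. It is a shape-level demonstration only and is never
# called by this module.
def _conformer_block_sketch(x: torch.Tensor, d_model: int = 512) -> torch.Tensor:
    """Run a schematic Conformer block on x of shape (batch, time, d_model)."""
    ffn_in = nn.Sequential(
        nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
    )
    attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
    conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
    ffn_out = nn.Sequential(
        nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
    )
    norm_att = nn.LayerNorm(d_model)
    norm_final = nn.LayerNorm(d_model)

    x = x + 0.5 * ffn_in(x)  # first half-step feed-forward residual
    norm_x = norm_att(x)
    x = x + attn(norm_x, norm_x, norm_x)[0]  # self-attention residual
    x = x + conv(x.transpose(1, 2)).transpose(1, 2)  # convolution residual
    x = x + 0.5 * ffn_out(x)  # second half-step feed-forward residual
    return norm_final(x)  # (batch, time, d_model)
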

class TransformerEncoderBase(abc.ABC, nn.Module):
    """The base class for Transformer-based encoders.

    Please set causal = True in streaming models.
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        input_layer: str, optional
            input layer type before Conformer,
            one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
            default "conv2d"
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        time_reduction: int, optional
            time reduction factor
            default 4
        dropout_rate: float, optional
            dropout rate. default 0.1
        padding_idx: int, optional
            padding index for input_layer=embed
            default -1
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention
            (Q*K^T + B) implemented in cmb.basics.embedding.
            [T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see
            transformer_base.py)
        positional_dropout_rate: float, optional
            dropout rate after positional encoding. default 0.0
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default None
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True).
            if True or feat_time, the extra padding is added into non-full
            supraframe utts in the batch.
            Default: none
        attention_group_size: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query
            Attention
            attention_group_size = attention_heads = Multi-Query Attention
    """

    def __init__(
        self,
        input_size: int,
        chunk_size: int | list[int],
        left_chunk: int | list[int],
        attention_dim: int = 256,
        attention_heads: int = 4,
        input_layer: str = "nemo_conv",
        cnn_out: int = -1,
        cnn_layer_norm: bool = False,
        time_reduction: int = 4,
        dropout_rate: float = 0.0,
        padding_idx: int = -1,
        relative_attention_bias_args: dict[str, Any] | None = None,
        positional_dropout_rate: float = 0.0,
        nemo_conv_settings: dict[str, Any] | None = None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        attention_group_size: int = 1,
        encoder_embedding_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.input_layer = input_layer
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.attention_dim = attention_dim
        self.num_heads = attention_heads
        self.attention_group_size = attention_group_size
        self.time_reduction = time_reduction
        self.nemo_conv_settings = nemo_conv_settings
        self.encoder_embedding_config = encoder_embedding_config

        if self.input_layer == "nemo_conv":
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.time_reduction,
                "feat_in": input_size,
                "feat_out": attention_dim,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                default_nemo_conv_settings.update(nemo_conv_settings)
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert i not in nemo_conv_settings, (
                        f"{i} should be specified outside of the NeMo dictionary"
                    )

            self.embed = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.pos_emb = AbsolutePositionalEncoding(
            attention_dim, positional_dropout_rate
        )

        self.relative_attention_bias_type = (
            relative_attention_bias_args.get("type")
            if relative_attention_bias_args
            else None
        )
        if self.relative_attention_bias_type == "t5":
            assert self.num_heads % self.attention_group_size == 0, (
                "attention_group_size must divide n_head"
            )
            self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
                self.num_heads // self.attention_group_size,
                max_distance=relative_attention_bias_args.get(
                    "t5_bias_max_distance", 1000
                ),
                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
            )
        else:
            raise NotImplementedError

        self.encoder_embedding = MeanVarianceNormLayer(
            self.encoder_embedding_config["input_size"]
        )

    def compute_lens_change(
        self, feature_lens: int | torch.Tensor
    ) -> int | torch.Tensor:
        """feature_lens: int
        return updated feature lens.

        This used to return a different lambda function for each case that
        computed the right thing. That does not work within Torchscript.
        If you really need this to be faster, create nn.Module()-s for all
        the cases and return one of them. Torchscript does support that.
        """
        if self.input_layer == "nemo_conv":
            # Handle the special causal case
            subsampling_causal_cond = self.nemo_conv_settings.get(
                "subsampling", "dw_striding"
            ) in [
                "dw_striding",
                "striding",
                "striding_conv1d",
            ]
            is_causal = self.nemo_conv_settings.get("is_causal", False)
            if is_causal and subsampling_causal_cond:
                lens_change = (
                    torch.ceil(feature_lens / self.time_reduction).long()
                    if isinstance(feature_lens, Tensor)
                    else math.ceil(feature_lens / self.time_reduction)
                )
                feature_lens_remainder = feature_lens % self.time_reduction
                if isinstance(feature_lens, Tensor):
                    lens_change[feature_lens_remainder != 1] += 1
                elif feature_lens_remainder != 1:
                    lens_change += 1
                return lens_change
            ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
            return ceil_func(feature_lens / self.time_reduction)

    @abc.abstractmethod
    def forward(self) -> Any:
        """Abstract forward method implementation."""

    def _chunk_size_selection(
        self,
        chunk_size: int | list[int] | None = None,
        left_chunk: int | list[int] | None = None,
    ) -> tuple[int, int]:
        """If chunk size is a list, we will randomly select a chunk size."""

        if chunk_size is None:
            chunk_size = self.chunk_size
        if left_chunk is None:
            left_chunk = self.left_chunk
        if isinstance(chunk_size, list):
            # Variable chunk size during training
            chunk_size_index = int(
                torch.randint(low=0, high=len(chunk_size), size=(1,))
            )
            chunk_size_train_eff = chunk_size[chunk_size_index]
            if not isinstance(left_chunk, list):
                raise ValueError(
                    "Since chunk_size is a list, left_chunk must be a list"
                )
            if len(left_chunk) != len(chunk_size):
                raise ValueError(
                    "The length of left_chunk must be the same as length of chunk_size."
                )
            left_chunk_train_eff = left_chunk[chunk_size_index]
        else:
            chunk_size_train_eff = chunk_size
            left_chunk_train_eff = left_chunk

        return chunk_size_train_eff, left_chunk_train_eff

    def _get_embed_class(self, embed: nn.Module) -> nn.Module:
        # pylint: disable=protected-access
        is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
        is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
        embed_class = embed
        if is_embed_using_act_chkpt:
            embed_class = embed._checkpoint_wrapped_module
        if is_embed_fsdp_wrapped:
            embed_class = embed.module
        return embed_class

    def _forward_embeddings_core(
        self, input_tensor: torch.Tensor, masks: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        embed_class = self._get_embed_class(self.embed)
        assert isinstance(embed_class, NemoConvSubsampling)
        input_tensor, masks = self.embed(input_tensor, masks)
        return input_tensor, masks

    def _position_embedding(
        self, input_tensor: torch.Tensor
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        pos_k = None
        pos_v = None
        if self.relative_attention_bias_layer is None:
            input_tensor = self.pos_emb(
                input_tensor
            )  # default to add abs sinusoid embedding
        return pos_k, pos_v

    def _streaming_mask(
        self,
        seq_len: int,
        batch_size: int,
        chunk_size: int | list[int],
        left_chunk: int | list[int],
    ) -> torch.Tensor:
        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
            chunk_size, left_chunk
        )

        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0, 18, 36, ...]
        chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)

        enc_streaming_mask = (
            adaptive_enc_mask(
                seq_len, chunk_start_idx, left_window=left_chunk_train_eff
            )
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(
        self,
        xs_pad: torch.Tensor,
        masks: torch.Tensor,
        chunk_size_nc: int | list[int] | None = None,
        left_chunk_nc: int | list[int] | None = None,
    ) -> (
        tuple[
            torch.Tensor,
            torch.Tensor | None,
            torch.Tensor | None,
            torch.Tensor,
            torch.Tensor,
        ]
        | tuple[
            torch.Tensor,
            torch.Tensor | None,
            torch.Tensor | None,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ]
    ):
        """Forwarding the inputs through the top embedding layers

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input mask
            chunk_size_nc: (optional, default is None) chunk size for
                non-causal layers
            left_chunk_nc: (optional, default is None) # of left chunks for
                non-causal layers
        """
        # pylint: disable=R0915
        # get new lens.
        seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
        if seq_len <= 0:
            raise ValueError(
                f"""The sequence length after time reduction is invalid:
                {seq_len}. Your input feature is too short. Consider
                filtering out the very short sentence from data
                loader""",
            )

        batch_size = xs_pad.shape[0]

        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.chunk_size, self.left_chunk
        )

        if xs_pad.is_cuda:
            enc_streaming_mask = enc_streaming_mask.cuda()
            xs_pad = xs_pad.cuda()

        input_tensor = xs_pad
        input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        if chunk_size_nc is not None:
            enc_streaming_mask_nc = self._streaming_mask(
                seq_len, batch_size, chunk_size_nc, left_chunk_nc
            )
            if xs_pad.is_cuda:
                enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
            if masks is not None:
                hs_mask_nc = masks & enc_streaming_mask_nc
            else:
                hs_mask_nc = enc_streaming_mask_nc
        else:
            hs_mask_nc = None

        pos_k, pos_v = self._position_embedding(input_tensor)

        if chunk_size_nc is None:
            return input_tensor, pos_k, pos_v, hs_mask, masks
        return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

    def get_offset(self) -> int:
        """Returns offset used when retaining inputs for decoding.

        This is, essentially, how many additional frames have to be added to
        the front-end CNN input to ensure it can produce a single output.
        So if the "padding" parameter is 0, typically offset will be > 0.
        """
        return get_offset(self.input_layer, self.time_reduction)
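

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a plain-torch
# re-implementation of the chunk-wise streaming mask that _streaming_mask
# obtains from adaptive_enc_mask. Frames are grouped into chunks of
# `chunk_size`; every frame may attend to all frames of its own chunk and of
# the `left_chunk` preceding chunks. This is written independently for
# demonstration and may differ from adaptive_enc_mask in edge cases.
def _streaming_mask_sketch(
    seq_len: int, chunk_size: int, left_chunk: int
) -> torch.Tensor:
    """Return a (seq_len, seq_len) boolean attention mask."""
    # chunk index of each frame, e.g. chunk_size=18 -> frames 0..17 are chunk 0
    chunk_idx = torch.arange(seq_len) // chunk_size
    q_chunk = chunk_idx.unsqueeze(1)  # (seq_len, 1) query-side chunk index
    k_chunk = chunk_idx.unsqueeze(0)  # (1, seq_len) key-side chunk index
    # attend to the own chunk and up to `left_chunk` chunks to the left
    return (k_chunk <= q_chunk) & (k_chunk >= q_chunk - left_chunk)
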

class ConformerEncoder(TransformerEncoderBase):
    """ConformerEncoder module.
    See the original paper for more details:
    https://arxiv.org/abs/2005.08100

    Please set causal = True in streaming models.
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        num_lang: int
            This parameter is used to store the number of languages in the
            lang_dict, only used for multiseed/multilingual models.
            default None.
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        linear_units:
            the number of units of position-wise feed forward.
            default 2048
        num_blocks:
            number of Transformer layers. default 6
        dropout_rate: float, optional
            dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer,
            one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
            default "conv2d"
        causal: bool, optional
            if set to True, convolutions have no access
            to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        ext_pw_out_channel: int, optional
            the number of channels for the CNN
            before depthwise_seperable_CNN.
            If 0 then use linear. default 0.
        ext_pw_kernel_size: int, optional
            kernel size of the conv before depthwise_seperable_CNN.
            only works for ext_pw_out_channel > 0.
            default 1
        depthwise_seperable_out_channel: int, optional
            the number of channels for
            depthwise_seperable_CNN.
            default 256.
        depthwise_multiplier: int, optional
            the multiplier for
            depthwise_seperable_CNN.
            default 1.
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
            by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
            by only the current chunk.
            default 0.
        kernel_size: int, optional
            the kernel size for depthwise_seperable_CNN.
            default 3.
        activation: str, optional
            FeedForward block activation.
            one of ["relu", "swish", "sigmoid"]
            default "relu".
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation used for the GLU in depthwise_seperable_CNN,
            default "sigmoid"
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU. default True
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
            otherwise, use GLUPointWiseConv module.
            default to False.
        attention_glu_type: str
            only works for glu_in_attention != 0
            default "swish".
        export: bool, optional
            if set to True, it removes the padding from convolutional layers
            and allows the ONNX conversion for inference.
            default False.
        activation_checkpointing: str, optional
            a dictionary of {"module", "interval", "offload"}, where
            "module": str
                accept ["transformer", "attention"] to select
                which module should do activation checkpointing.
            "interval": int, default 1,
                interval of applying activation checkpointing,
                interval = 1 means that we apply checkpointing
                on every layer (if activated), otherwise,
                we apply it every `interval` layers.
            "offload": bool, default False,
                if set to True, we offload activation to cpu and
                reload it during backward, otherwise,
                we recalculate activation in backward.
            default "".
        extra_layer_output_idx: int
            the layer index to be exposed.
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention
            (Q*K^T + B) implemented in cmb.basics.embedding.
            [T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see
            transformer_base.py)
        time_reduction: int, optional
            time reduction factor
            default 4
        use_pt_scaled_dot_product_attention: whether to use pytorch scaled
            dot product attention in training.
            Default: False
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default: None
            usage: nemo_conv_settings=
            {
                "subsampling":
                dw_striding/striding/dw_striding_conv1d/striding_conv1d,
                "conv_channels": int,
                "subsampling_conv_chunking_factor": int,
                "is_causal": True/False
            }
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True)
            Default: none
        replication_pad_for_subsample_embedding: For batched-streaming
            decoding, use "replication" padding for the cache at start of
            utterance.
            Default: False
        attention_group_size: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query
            Attention
            attention_group_size = attention_heads = Multi-Query Attention
    """

    extra_multi_layer_output_idxs: list[int]

    def __init__(  # pylint: disable-all
        self,
        input_size: int,
        chunk_size: int | list[int],
        left_chunk: int | list[int],
        num_lang: int | None = None,
        attention_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        input_layer: str = "nemo_conv",
        causal: bool = True,
        batch_norm: bool = False,
        cnn_out: int = -1,
        cnn_layer_norm: bool = False,
        ext_pw_out_channel: int = 0,
        ext_pw_kernel_size: int = 1,
        depthwise_seperable_out_channel: int = 256,
        depthwise_multiplier: int = 1,
        chunk_se: int = 0,
        kernel_size: int = 3,
        activation: str = "relu",
        conv_activation: str = "relu",
        conv_glu_type: str = "sigmoid",
        bias_in_glu: bool = True,
        linear_glu_in_convm: bool = False,
        attention_glu_type: str = "swish",
        export: bool = False,
        extra_layer_output_idx: int = -1,
        extra_multi_layer_output_idxs: list[int] = [],  # noqa
        activation_checkpointing: str = "",
        relative_attention_bias_args: dict[str, Any] | None = None,
        time_reduction: int = 4,
        use_pt_scaled_dot_product_attention: bool = False,
        nemo_conv_settings: dict[str, Any] | None = None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        replication_pad_for_subsample_embedding: bool = False,
        attention_group_size: int = 1,
        encoder_embedding_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            input_size,
            chunk_size,
            left_chunk,
            attention_dim,
            attention_heads,
            input_layer,
            cnn_out,
            cnn_layer_norm,
            time_reduction,
            dropout_rate=dropout_rate,
            relative_attention_bias_args=relative_attention_bias_args,
            positional_dropout_rate=0.0,
            nemo_conv_settings=nemo_conv_settings,
            conv2d_extra_padding=conv2d_extra_padding,
            attention_group_size=attention_group_size,
            encoder_embedding_config=encoder_embedding_config,
        )
        self.num_blocks = num_blocks
        self.num_lang = num_lang
        self.kernel_size = kernel_size
        self.replication_pad_for_subsample_embedding: bool = (
            replication_pad_for_subsample_embedding
        )
        assert self.num_heads % attention_group_size == 0, (
            "attention_group_size must divide n_head"
        )
        self.num_heads_k = self.num_heads // attention_group_size

        self.encoders = MultiSequential(
            *[
                ConformerEncoderLayer(
                    d_model=attention_dim,
                    ext_pw_out_channel=ext_pw_out_channel,
                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                    depthwise_multiplier=depthwise_multiplier,
                    n_head=attention_heads,
                    d_ffn=linear_units,
                    ext_pw_kernel_size=ext_pw_kernel_size,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    causal=causal,
                    batch_norm=batch_norm,
                    activation=activation,
                    chunk_se=chunk_se,
                    chunk_size=chunk_size,
                    conv_activation=conv_activation,
                    conv_glu_type=conv_glu_type,
                    bias_in_glu=bias_in_glu,
                    linear_glu_in_convm=linear_glu_in_convm,
                    attention_glu_type=attention_glu_type,
                    activation_checkpointing=activation_checkpointing,
                    export=export,
                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                    attn_group_sizes=attention_group_size,
                )
                for _ in range(num_blocks)
            ]
        )
        self.extra_layer_output_idx = extra_layer_output_idx
        self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
        # Make a zeros scalar we can use in get_initial_state to determine
        # the device and the needed dtype:
        self.register_buffer("dev_type", torch.zeros(()), persistent=False)

    def init_relative_attention_bias(
        self, input_tensor: torch.Tensor
    ) -> torch.Tensor | None:
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)

    def calculate_hs_mask(
        self, xs_pad: torch.Tensor, device: torch.device, mask: torch.Tensor | None
    ) -> torch.Tensor:
        max_audio_length = xs_pad.shape[1]
        batch_size = xs_pad.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.chunk_size, self.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        padding_length = feature_lens
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
        pad_mask = pad_mask.unsqueeze(1)
        pad_mask = pad_mask & enc_streaming_mask
        return pad_mask

    @torch.jit.ignore
    def forward(
        self, xs_pad: torch.Tensor, masks: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Conformer Forward function

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                post-embedding input lengths
        """
        xs_pad = self.encoder_embedding(xs_pad)
        input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
            xs_pad, masks
        )

        unfolded = False
        ori_bz, seq_len, D = input_tensor.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len, unfold it into chunks
            # of max_seq_len
            unfolded = True
            # the unfold op will drop residual frames, pad it to the multiple
            # of max_seq_len
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            else:
                chunk_pad_size = 0
            if chunk_pad_size > 0:
                input_tensor_pad = F.pad(
                    input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
                )
                input_tensor = input_tensor_pad.to(input_tensor.device)
            input_tensor = unfold_tensor(input_tensor, max_seq_len)
            if masks is not None:
                # revise hs_mask here because the previously calculated hs_mask
                # did not consider extra pad
                subsampled_pad_mask = masks.squeeze(
                    1
                )  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsamlped_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )  # extra padding to the pad mask
                extra_padded_subsamlped_pad_mask = (
                    extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
                )
                masks_unfold = unfold_tensor(
                    extra_padded_subsamlped_pad_mask, max_seq_len
                )  # unfold the pad mask like we did to the input tensor
                masks_unfold = masks_unfold.squeeze(
                    -1
                ).bool()  # unfold op does not support bool tensor
            else:
                masks_unfold = None
            hs_mask = self.calculate_hs_mask(
                input_tensor, input_tensor.device, masks_unfold
            )  # calculate hs_mask based on the unfolded pad mask

        # layer_emb = None

        relative_attention_bias = self.init_relative_attention_bias(input_tensor)

        _simplified_path = (
            self.extra_layer_output_idx == -1 and relative_attention_bias is None
        )

        if _simplified_path:
            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
        else:
            for i, layer in enumerate(self.encoders):
                input_tensor, _, _, _ = layer(
                    input_tensor,
                    pos_k,
                    pos_v,
                    hs_mask,
                    relative_attention_bias=relative_attention_bias,
                )

                # if i == self.extra_layer_output_idx:
                #     layer_emb = input_tensor

        if unfolded:
            embed_dim = input_tensor.shape[-1]
            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
            # if we ever padded before unfolding, we need to remove the padding
            if chunk_pad_size > 0:
                input_tensor = input_tensor[:, :-chunk_pad_size, :]

        return input_tensor, masks  # , layer_emb
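

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the padding /
# chunking / refolding bookkeeping that ConformerEncoder.forward applies when
# the subsampled sequence exceeds max_seq_len (500). The reshape below stands
# in for the project's unfold_tensor helper, which is assumed to split
# (B, T, D) into (B * T / max_seq_len, max_seq_len, D).
def _unfold_refold_sketch(x: torch.Tensor, max_seq_len: int = 500) -> torch.Tensor:
    """Round-trip x of shape (B, T, D) through fixed-size chunks."""
    bsz, seq_len, dim = x.shape
    pad = (max_seq_len - seq_len % max_seq_len) % max_seq_len
    if pad > 0:
        x = F.pad(x, (0, 0, 0, pad), "constant", 0)  # pad the time dimension
    chunks = x.reshape(bsz * (x.shape[1] // max_seq_len), max_seq_len, dim)
    # ... the encoder layers would run on `chunks` here ...
    out = chunks.reshape(bsz, -1, dim)  # fold the chunks back per utterance
    if pad > 0:
        out = out[:, :-pad, :]  # drop the padding that was added above
    return out  # same shape as the input (B, T, D)
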

class WindowQformer(nn.Module):
    """Window-level Qformer"""

    def __init__(
        self,
        window_size: int = 8,
        num_queries: int = 1,
        num_blocks: int = 2,
        attention_dim: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        dropout_rate: float = 0.0,
        normalize_before: bool = True,
    ):
        super().__init__()

        self.decoders = nn.ModuleList(
            [
                nn.TransformerDecoderLayer(
                    d_model=attention_dim,
                    nhead=attention_heads,
                    dim_feedforward=linear_units,
                    dropout=dropout_rate,
                    activation="relu",
                    batch_first=True,
                    norm_first=normalize_before,  # TODO need to verify
                )
                for _ in range(num_blocks)
            ]
        )

        self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
        self.after_norm = (
            nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
        )
        self.window_size = window_size

    def forward(
        self,
        audio_embed: torch.Tensor,
        mask: torch.Tensor | None,
        embed_len: int | None = None,
    ) -> tuple[torch.Tensor, int | None]:
        """forward decoder"""
        # audio_embed: N x T x D => N x D x T

        audio_embed = audio_embed.transpose(1, 2)
        # audio_embed: N x D x 1 x T => N x DK x T'
        padding = audio_embed.shape[-1] % self.window_size
        if padding > 0:
            audio_embed = F.pad(
                audio_embed, (0, self.window_size - padding), "constant", 0
            )

        embed_chunk = F.unfold(
            audio_embed[..., None, :],
            kernel_size=(1, self.window_size),
            stride=(1, self.window_size),
        )
        bsz, _, slen = embed_chunk.shape
        # N x D x K x T'
        embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen)
        # N x T' x K x D
        embed_chunk = embed_chunk.transpose(1, 3).contiguous()
        # NT' x K x D
        embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1)
        # NT' x 1 x D
        q = self.queries.expand(bsz * slen, -1, -1)
        for layer in self.decoders:
            q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)

        if self.after_norm is not None:
            q = self.after_norm(q)

        if embed_len is not None:
            embed_len = embed_len // self.window_size
        # N x T' x D
        out = q.view(bsz, slen, -1)

        return out, embed_len
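

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how WindowQformer
# regroups a (N, T, D) sequence into non-overlapping windows with F.unfold so
# that one learned query can pool each window. Mean pooling stands in for the
# transformer-decoder cross-attention used above; the shapes match the
# comments in WindowQformer.forward.
def _window_pooling_sketch(
    audio_embed: torch.Tensor, window_size: int = 8
) -> torch.Tensor:
    """Pool audio_embed of shape (N, T, D) down to (N, ceil(T / window_size), D)."""
    x = audio_embed.transpose(1, 2)  # N x T x D -> N x D x T
    pad = (window_size - x.shape[-1] % window_size) % window_size
    if pad > 0:
        x = F.pad(x, (0, pad), "constant", 0)
    # N x D x 1 x T -> N x (D * window_size) x T', with T' = T_padded / window_size
    chunks = F.unfold(
        x[..., None, :], kernel_size=(1, window_size), stride=(1, window_size)
    )
    bsz, _, slen = chunks.shape
    chunks = chunks.view(bsz, -1, window_size, slen)  # N x D x K x T'
    return chunks.mean(dim=2).transpose(1, 2)  # N x T' x D
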

class AudioEmbedding(nn.Module):
    """Audio embedding."""

    def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
        super().__init__()
        self.config = config
        # n_embed or hidden_size for text LM
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        # self.wte = nn.Embedding(config.vocab_size, hidden_size)

        audio_dim_out = (
            None  # Set this variable according to the actual audio processor
        )
        self.layer_idx = -2

        if (
            isinstance(config.audio_processor, dict)
            and config.audio_processor.get("name", None) == "cascades"
        ):
            encoder_config = config.audio_processor.get("config", None)
            assert encoder_config is not None
            self.encoder = ConformerEncoder(**encoder_config)

            audio_dim_out = encoder_config["attention_dim"]
            n_mels = encoder_config["input_size"]
        else:
            raise NotImplementedError("")

        assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
        self.audio_dim_out = audio_dim_out
        self.audio_dim_in = n_mels

        self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)

        self.downsample_rate = kwargs.get("downsample_rate", 1)

        if kwargs.get("use_qformer", False):
            qformer_config = kwargs.get("qformer_config", {})
            qformer_config["attention_dim"] = audio_dim_out
            self.qformer = WindowQformer(**qformer_config)
        else:
            self.qformer = None

        if kwargs.get("use_conv_downsample", False):
            assert self.qformer is None, (
                "using qformer and conv downsample together is not supported"
            )
            nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.downsample_rate,
                "feat_in": audio_dim_out,
                "feat_out": audio_dim_out,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                default_nemo_conv_settings.update(nemo_conv_settings)
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert i not in nemo_conv_settings, (
                        f"{i} should be specified outside of the NeMo dictionary"
                    )

            self.conv_ds = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            self.conv_ds = None

        projection_cls = kwargs.get("projection_cls", "linear")
        if projection_cls == "linear":
            self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
        elif projection_cls == "mlp":
            # follow llava-v1.5's implementation
            # (do not use image_projection and image_proj_norm)
            dim_projection = hidden_size
            depth = 2
            self.linear_downsample_rate = (
                1 if (self.qformer or self.conv_ds) else self.downsample_rate
            )
            layers = [
                nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
            ]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
            self.audio_projection = nn.Sequential(*layers)
            # NOTE vision-speech tasks use a separate projection layer
            layers = [
                nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
            ]
            for _ in range(1, depth):
                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
            self.audio_projection_for_vision = nn.Sequential(*layers)
        else:
            raise NotImplementedError(
                f"projection_cls = {projection_cls}, not implemented"
            )

        # TODO: audio sequence compression - Qformer
        self.vocab_size = config.vocab_size
        self.input_embeds = None
        self.audio_embed_sizes = None

    def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
        self.input_embeds = input_embeds

    def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
        self.audio_embed_sizes = audio_embed_sizes

    def get_audio_features(
        self,
        input_embeds: torch.Tensor,
        audio_attention_mask: torch.Tensor | None = None,
        audio_projection_mode: str = "speech",
    ) -> torch.Tensor:
        """
        arguments:
            input_embeds: audio features (B, T, D) B: num audios in a sequence
        """
        if self.freeze_audio_processor:
            with torch.no_grad():
                audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
        else:
            audio_features, masks = self.encoder(input_embeds, audio_attention_mask)

        if self.qformer is not None:
            audio_features, _ = self.qformer(audio_features, mask=None)

        if self.conv_ds is not None:
            if masks is not None:
                masks = masks.squeeze(1)

            audio_features, masks = self.conv_ds(audio_features, mask=masks)

        if self.linear_downsample_rate != 1:
            bs, seq_len, feat_dim = audio_features.size()
            padding = seq_len % self.linear_downsample_rate
            if padding > 0:
                audio_features = F.pad(
                    audio_features,
                    (0, 0, 0, self.linear_downsample_rate - padding),
                    "constant",
                    0,
                )

            seq_len = audio_features.size(1)
            audio_features = audio_features.view(
                bs,
                seq_len // self.linear_downsample_rate,
                feat_dim * self.linear_downsample_rate,
            )

        if audio_projection_mode == "speech":
            audio_set_tensor = self.audio_projection(audio_features)
        elif audio_projection_mode == "vision":
            audio_set_tensor = self.audio_projection_for_vision(audio_features)
        else:
            raise ValueError(
                f"audio_projection_mode = {audio_projection_mode} not implemented"
            )

        return audio_set_tensor

    def forward(
        self,
        audio_features: torch.Tensor,
        audio_attention_mask: torch.Tensor | None = None,
        audio_projection_mode: str = "speech",
    ) -> torch.Tensor:
        """
        arguments:
            audio_features: audio features (T, D)

        returns:
            audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
        """
        audio_embeds = self.get_audio_features(
            audio_features.unsqueeze(0),
            audio_attention_mask=audio_attention_mask,
            audio_projection_mode=audio_projection_mode,
        )
        return audio_embeds.squeeze(0)
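

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the frame-stacking
# downsample plus projection performed at the end of
# AudioEmbedding.get_audio_features when linear_downsample_rate > 1. The
# feature and hidden sizes below are arbitrary example values, and a single
# nn.Linear stands in for the configured audio_projection module.
def _downsample_and_project_sketch(
    audio_features: torch.Tensor, downsample_rate: int = 2, hidden_size: int = 3072
) -> torch.Tensor:
    """Map (B, T, D) encoder output to (B, ceil(T / rate), hidden_size)."""
    bs, seq_len, feat_dim = audio_features.size()
    pad = (downsample_rate - seq_len % downsample_rate) % downsample_rate
    if pad > 0:
        audio_features = F.pad(audio_features, (0, 0, 0, pad), "constant", 0)
    # stack every `downsample_rate` consecutive frames into one wider frame
    audio_features = audio_features.view(
        bs, audio_features.size(1) // downsample_rate, feat_dim * downsample_rate
    )
    projection = nn.Linear(feat_dim * downsample_rate, hidden_size)
    return projection(audio_features)  # (B, T // rate, hidden_size)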