# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) # but implemented by the Phi-Speech team #!/usr/bin/env python3 import abc import math from typing import Any, Literal import numpy as np import torch import torch.nn.functional as F from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointWrapper, ) from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from transformers import PretrainedConfig from vllm.model_executor.models.phi4mm_utils import ( AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, MultiHeadedAttention, MultiSequential, NemoConvSubsampling, T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor, ) _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> class ConformerEncoderLayer(nn.Module): """ConformerEncoder Layer module. for more details see conformer paper: https://arxiv.org/abs/2005.08100 This module implement the Conformer block layer. Args: d_model: int attention dim. ext_pw_out_channel: int if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equals to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. n_head: int the number of heads for multihead attention module. d_ffn: int output size of the feed_forward blocks. ext_pw_kernel_size: int kernel size of the conv pointwise of the conformer. kernel_size: int kernel size. dropout_rate: float dropout rate. causal: bool, optional if set to True, convolution have no access to future frames. default False. batch_norm: bool, optional if set to True, apply batchnorm before activation in ConvModule layer of the conformer. default False activation: str, optional activation function name, one of ["relu", "swish", "sigmoid"], sigmoid activation is only used with "glu_in_fnn=True", default "relu". chunk_se: int, optional 0 for offline SE. 1 for streaming SE, where mean is computed by accumulated history until current chunk_se. 2 for streaming SE, where mean is computed by only the current chunk. default 0. chunk_size: int, optional chunk_size for cnn. default 18 conv_activation: str, optional activation function used in ConvModule part of the conformer, default "relu". conv_glu_type: str, optional activation function used for the glu inside the ConvModule part of the conformer. default: "sigmoid". bias_in_glu: bool, optional if set to True, use additive bias in the weight module before GLU. linear_glu_in_convm: bool, optional if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. attention_inner_dim: int, optional if equal to -1, attention dim for linears k/q/v is equal to d_model. otherwise attention_inner_dim is used. default -1. attention_glu_type: str, optional activation function for glu used in the multihead attention, default "swish". activation_checkpointing: str, optional a dictionary of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. "interval": int, default 1, interval of applying activation checkpointing, interval = 1 means that we apply checkpointing on every layer (if activation), otherwise, we apply it every x interval. "offload": bool, default False, if set to True, we offload activation to cpu and reload it during backward, otherwise, we recalculate activation in backward. default "". export: bool, optional if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional if set to True, use pytorch's scaled dot product attention implementation in training. attn_group_sizes: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attn_group_sizes < attention_heads = Grouped-Query Attention attn_group_sizes = attention_heads = Multi-Query Attention """ def __init__( self, d_model: int = 512, ext_pw_out_channel: int = 0, depthwise_seperable_out_channel: int = 256, depthwise_multiplier: int = 1, n_head: int = 4, d_ffn: int = 2048, ext_pw_kernel_size: int = 1, kernel_size: int = 3, dropout_rate: float = 0.1, causal: bool = False, batch_norm: bool = False, activation: str = "relu", chunk_se: int = 0, chunk_size: int = 18, conv_activation: str = "relu", conv_glu_type: str = "sigmoid", bias_in_glu: bool = True, linear_glu_in_convm: bool = False, attention_inner_dim: int = -1, attention_glu_type: str = "swish", activation_checkpointing: str = "", export: bool = False, use_pt_scaled_dot_product_attention: bool = False, attn_group_sizes: int = 1, ) -> None: super().__init__() self.feed_forward_in = FeedForward( d_model=d_model, d_inner=d_ffn, dropout_rate=dropout_rate, activation=activation, bias_in_glu=bias_in_glu, ) self.self_attn = MultiHeadedAttention( n_head, d_model, dropout_rate, attention_inner_dim, attention_glu_type, bias_in_glu, use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, group_size=attn_group_sizes, ) self.conv = ConvModule( d_model, ext_pw_out_channel, depthwise_seperable_out_channel, ext_pw_kernel_size, kernel_size, depthwise_multiplier, dropout_rate, causal, batch_norm, chunk_se, chunk_size, conv_activation, conv_glu_type, bias_in_glu, linear_glu_in_convm, export=export, ) self.feed_forward_out = FeedForward( d_model=d_model, d_inner=d_ffn, dropout_rate=dropout_rate, activation=activation, bias_in_glu=bias_in_glu, ) self.layer_norm_att = nn.LayerNorm(d_model) self.layer_norm = nn.LayerNorm(d_model) def forward( self, x: torch.Tensor, pos_k: torch.Tensor, pos_v: torch.Tensor, mask: torch.Tensor, relative_attention_bias: Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ConformerEncoder forward. Args: x: input feature of shape (batch, max_time_in, size) pos_k: positional key embedding. pos_v: positional value embedding. mask: mask for x (batch, max_time_in) relative_attention_bias: bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) x = x + self.self_attn( norm_x, norm_x, norm_x, pos_k, pos_v, mask, relative_attention_bias=relative_attention_bias, ) x = x + self.conv(x) x = x + 0.5 * self.feed_forward_out(x) out = self.layer_norm(x) return out, pos_k, pos_v, mask class TransformerEncoderBase(abc.ABC, nn.Module): """The Base class for Transformer based encoders Please set causal = True in streaming model Args: input_size: int input feature dimension. chunk_size: int, list(int) Number of frames for each chunk This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training Some examples for the 2 cases: chunk_size = 12 chunk_size = [6, 8, 12, 24] left_chunk: int, list(int) Number of chunks used for masking in streaming mode. This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training. When chunk_size is a list, left_chunk must be a list with same length. Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] attention_dim: int, optional attention dimension. default 256. attention_heads: int, optional the number of heads. default 4 input_layer: str, optional input layer type before Conformer, one of ["linear", "conv2d", "custom", "vgg2l", "embed"], default "conv2d" cnn_out: int, optional the number of CNN channels before Conformer. default -1. cnn_layer_norm: bool, optional layer norm between Conformer and the first CNN. default False. time_reduction: int, optional time reduction factor default 4 dropout_rate: float, optional dropout rate. default 0.1 padding_idx: int, optional padding index for input_layer=embed default -1 relative_attention_bias_args: dict, optional use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} additional method-specific arguments can be provided (see transformer_base.py) positional_dropout_rate: float, optional dropout rate after positional encoding. default 0.0 nemo_conv_settings: dict, optional A dictionary of settings for NeMo Subsampling. default None conv2d_extra_padding: str, optional Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True). if True or feat_time, the extra padding is added into non full supraframe utts in batch. Default: none attention_group_size: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attention_heads = Multi-Query Attention """ def __init__( self, input_size: int, chunk_size: int | list[int], left_chunk: int | list[int], attention_dim: int = 256, attention_heads: int = 4, input_layer: str = "nemo_conv", cnn_out: int = -1, cnn_layer_norm: bool = False, time_reduction: int = 4, dropout_rate: float = 0.0, padding_idx: int = -1, relative_attention_bias_args: dict[str, Any] | None = None, positional_dropout_rate: float = 0.0, nemo_conv_settings: dict[str, Any] | None = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", attention_group_size: int = 1, encoder_embedding_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.input_size = input_size self.input_layer = input_layer self.chunk_size = chunk_size self.left_chunk = left_chunk self.attention_dim = attention_dim self.num_heads = attention_heads self.attention_group_size = attention_group_size self.time_reduction = time_reduction self.nemo_conv_settings = nemo_conv_settings self.encoder_embedding_config = encoder_embedding_config if self.input_layer == "nemo_conv": default_nemo_conv_settings = { "subsampling": "dw_striding", "subsampling_factor": self.time_reduction, "feat_in": input_size, "feat_out": attention_dim, "conv_channels": 256, "subsampling_conv_chunking_factor": 1, "activation": nn.ReLU(), "is_causal": False, } # Override any of the defaults with the incoming, user settings if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: assert i not in nemo_conv_settings, ( "{i} should be specified outside of the NeMo dictionary" ) self.embed = NemoConvSubsampling( **default_nemo_conv_settings, ) else: raise ValueError("unknown input_layer: " + input_layer) self.pos_emb = AbsolutePositionalEncoding( attention_dim, positional_dropout_rate ) self.relative_attention_bias_type = ( relative_attention_bias_args.get("type") if relative_attention_bias_args else None ) if self.relative_attention_bias_type == "t5": assert self.num_heads % self.attention_group_size == 0, ( "attention_group_size must divide n_head" ) self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( self.num_heads // self.attention_group_size, max_distance=relative_attention_bias_args.get( "t5_bias_max_distance", 1000 ), symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), ) else: raise NotImplementedError self.encoder_embedding = MeanVarianceNormLayer( self.encoder_embedding_config["input_size"] ) def compute_lens_change( self, feature_lens: int | torch.Tensor ) -> int | torch.Tensor: """feature_lens: int return updated feature lens. This used to return a different lambda function for each case that computed the right thing. That does not work within Torchscript. If you really need this to be faster, create nn.Module()-s for all the cases and return one of them. Torchscript does support that. """ if self.input_layer == "nemo_conv": # Handle the special causal case subsampling_causal_cond = self.nemo_conv_settings.get( "subsampling", "dw_striding" ) in [ "dw_striding", "striding", "striding_conv1d", ] is_causal = self.nemo_conv_settings.get("is_causal", False) if is_causal and subsampling_causal_cond: lens_change = ( torch.ceil(feature_lens / self.time_reduction).long() if isinstance(feature_lens, Tensor) else math.ceil(feature_lens / self.time_reduction) ) feature_lens_remainder = feature_lens % self.time_reduction if isinstance(feature_lens, Tensor): lens_change[feature_lens_remainder != 1] += 1 elif feature_lens_remainder != 1: lens_change += 1 return lens_change ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod def forward(self) -> Any: """Abstract forward method implementation.""" def _chunk_size_selection( self, chunk_size: int | list[int] | None = None, left_chunk: int | list[int] | None = None, ) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: chunk_size = self.chunk_size if left_chunk is None: left_chunk = self.left_chunk if isinstance(chunk_size, list): # Variable chunk size during training chunk_size_index = int( torch.randint(low=0, high=len(chunk_size), size=(1,)) ) chunk_size_train_eff = chunk_size[chunk_size_index] if not isinstance(left_chunk, list): raise ValueError( "Since chunk_size is a list, left_chunk must be a list" ) if len(left_chunk) != len(chunk_size): raise ValueError( "The length of left_chunk must be the same as length of chunk_size." ) left_chunk_train_eff = left_chunk[chunk_size_index] else: chunk_size_train_eff = chunk_size left_chunk_train_eff = left_chunk return chunk_size_train_eff, left_chunk_train_eff def _get_embed_class(self, embed: nn.Module) -> nn.Module: # pylint: disable=protected-access is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) embed_class = embed if is_embed_using_act_chkpt: embed_class = embed._checkpoint_wrapped_module if is_embed_fsdp_wrapped: embed_class = embed.module return embed_class def _forward_embeddings_core( self, input_tensor: torch.Tensor, masks: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) return input_tensor, masks def _position_embedding( self, input_tensor: torch.Tensor ) -> tuple[torch.Tensor | None, torch.Tensor | None]: pos_k = None pos_v = None if self.relative_attention_bias_layer is None: input_tensor = self.pos_emb( input_tensor ) # default to add abs sinusoid embedding return pos_k, pos_v def _streaming_mask( self, seq_len: int, batch_size: int, chunk_size: int | list[int], left_chunk: int | list[int], ) -> torch.Tensor: chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( chunk_size, left_chunk ) # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) enc_streaming_mask = ( adaptive_enc_mask( seq_len, chunk_start_idx, left_window=left_chunk_train_eff ) .unsqueeze(0) .expand([batch_size, -1, -1]) ) return enc_streaming_mask def forward_embeddings( self, xs_pad: torch.Tensor, masks: torch.Tensor, chunk_size_nc: int | list[int] | None = None, left_chunk_nc: int | list[int] | None = None, ) -> ( tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor, torch.Tensor, ] | tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor, ] ): """Forwarding the inputs through the top embedding layers Args: xs_pad: torch.Tensor input tensor masks: torch.Tensor input mask chunk_size_nc: (optional, default is None) chunk size for non-causal layers left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers """ # pylint: disable=R0915 # get new lens. seq_len = int(self.compute_lens_change(xs_pad.shape[1])) if seq_len <= 0: raise ValueError( f"""The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short. Consider filtering out the very short sentence from data loader""", ) batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask( seq_len, batch_size, self.chunk_size, self.left_chunk ) if xs_pad.is_cuda: enc_streaming_mask = enc_streaming_mask.cuda() xs_pad = xs_pad.cuda() input_tensor = xs_pad input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: hs_mask = masks & streaming_mask elif masks is not None: hs_mask = masks else: hs_mask = streaming_mask if chunk_size_nc is not None: enc_streaming_mask_nc = self._streaming_mask( seq_len, batch_size, chunk_size_nc, left_chunk_nc ) if xs_pad.is_cuda: enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() if masks is not None: hs_mask_nc = masks & enc_streaming_mask_nc else: hs_mask_nc = enc_streaming_mask_nc else: hs_mask_nc = None pos_k, pos_v = self._position_embedding(input_tensor) if chunk_size_nc is None: return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc def get_offset(self) -> int: """Returns offset used when retaining inputs for decoding. This is essentially, how many additional frames have to be added to the front-end CNN input to ensure it can produce a single output. So if the "padding" parameter is 0, typically offset will be > 0. """ return get_offset(self.input_layer, self.time_reduction) class ConformerEncoder(TransformerEncoderBase): """ConformerEncoder module. see original paper for more details: https://arxiv.org/abs/2005.08100 Please set causal = True in streaming model Args: input_size: int input feature dimension. chunk_size: int, list(int) Number of frames for each chunk This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training Some examples for the 2 cases: chunk_size = 12 chunk_size = [6, 8, 12, 24] left_chunk: int, list(int) Number of chunks used for masking in streaming mode. This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training. When chunk_size is a list, left_chunk must be a list with same length. Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] num_lang: int This parameter is used to store the number of languages in the lang_dict, only used for multiseed/multilingual models. default None. attention_dim: int, optional attention dimension. default 256. attention_heads: int, optional the number of heads. default 4 linear_units: the number of units of position-wise feed forward. default 2048 num_block: number of Transformer layer. default 6 dropout_rate: float, optional dropout rate. default 0.1 input_layer: str, optional input layer type before Conformer, one of ["linear", "conv2d", "custom", "vgg2l", "embed"], default "conv2d" causal: bool, optional if set to True, convolution have no access to future frames. default False. batch_norm: bool, optional if set to True, apply batchnorm before activation in ConvModule layer of the conformer. default False cnn_out: int, optional the number of CNN channels before Conformer. default -1. cnn_layer_norm: bool, optional layer norm between Conformer and the first CNN. default False. ext_pw_out_channel: int, optional the number of channel for CNN before depthwise_seperable_CNN. If 0 then use linear. default 0. ext_pw_kernel_size: int, optional kernel size of N before depthwise_seperable_CNN. only work for ext_pw_out_channel > 0. default 1 depthwise_seperable_out_channel: int, optional the number of channel for depthwise_seperable_CNN. default 256. depthwise_multiplier: int, optional the number of multiplier for depthwise_seperable_CNN. default 1. chunk_se: int, optional 0 for offline SE. 1 for streaming SE, where mean is computed by accumulated history until current chunk_se. 2 for streaming SE, where mean is computed by only the current chunk. default 0. kernel_size: int, optional the number of kernels for depthwise_seperable_CNN. default 3. activation: str, optional FeedForward block activation. one of ["relu", "swish", "sigmoid"] default "relu". conv_activation: str, optional activation function used in ConvModule part of the conformer, default "relu". conv_glu_type: str, optional activation used use glu in depthwise_seperable_CNN, default "sigmoid" bias_in_glu: bool, optional if set to True, use additive bias in the weight module before GLU. default True linear_glu_in_convm: bool, optional if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. attention_glu_type: str only work for glu_in_attention !=0 default "swish". export: bool, optional if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. activation_checkpointing: str, optional a dictionarry of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. "interval": int, default 1, interval of applying activation checkpointing, interval = 1 means that we apply checkpointing on every layer (if activation), otherwise, we apply it every x interval. "offload": bool, default False, if set to True, we offload activation to cpu and reload it during backward, otherwise, we recalculate activation in backward. default "". extra_layer_output_idx: int the layer index to be exposed. relative_attention_bias_args: dict, optional use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} additional method-specific arguments can be provided (see transformer_base.py) time_reduction: int optional time reduction factor default 4 use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention in training. Default: False nemo_conv_settings: dict, optional A dictionary of settings for NeMo Subsampling. default: None usage: nemo_conv_settings= { "subsampling": dw_striding/striding/dw_striding_conv1d/striding_conv1d, "conv_channels": int, "subsampling_conv_chunking_factor": int, "is_causal": True/False } conv2d_extra_padding: str, optional Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True) Default: none replication_pad_for_subsample_embedding: For batched-streaming decoding, use "replication" padding for the cache at start of utterance. Default: False attention_group_size: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attention_heads = Multi-Query Attention """ extra_multi_layer_output_idxs: list[int] def __init__( # pylint: disable-all self, input_size: int, chunk_size: int | list[int], left_chunk: int | list[int], num_lang: int | None = None, attention_dim: int = 256, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.1, input_layer: str = "nemo_conv", causal: bool = True, batch_norm: bool = False, cnn_out: int = -1, cnn_layer_norm: bool = False, ext_pw_out_channel: int = 0, ext_pw_kernel_size: int = 1, depthwise_seperable_out_channel: int = 256, depthwise_multiplier: int = 1, chunk_se: int = 0, kernel_size: int = 3, activation: str = "relu", conv_activation: str = "relu", conv_glu_type: str = "sigmoid", bias_in_glu: bool = True, linear_glu_in_convm: bool = False, attention_glu_type: str = "swish", export: bool = False, extra_layer_output_idx: int = -1, extra_multi_layer_output_idxs: list[int] = [], # noqa activation_checkpointing: str = "", relative_attention_bias_args: dict[str, Any] | None = None, time_reduction: int = 4, use_pt_scaled_dot_product_attention: bool = False, nemo_conv_settings: dict[str, Any] | None = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", replication_pad_for_subsample_embedding: bool = False, attention_group_size: int = 1, encoder_embedding_config: dict[str, Any] | None = None, ) -> None: super().__init__( input_size, chunk_size, left_chunk, attention_dim, attention_heads, input_layer, cnn_out, cnn_layer_norm, time_reduction, dropout_rate=dropout_rate, relative_attention_bias_args=relative_attention_bias_args, positional_dropout_rate=0.0, nemo_conv_settings=nemo_conv_settings, conv2d_extra_padding=conv2d_extra_padding, attention_group_size=attention_group_size, encoder_embedding_config=encoder_embedding_config, ) self.num_blocks = num_blocks self.num_lang = num_lang self.kernel_size = kernel_size self.replication_pad_for_subsample_embedding: bool = ( replication_pad_for_subsample_embedding ) assert self.num_heads % attention_group_size == 0, ( "attention_group_size must divide n_head" ) self.num_heads_k = self.num_heads // attention_group_size self.encoders = MultiSequential( *[ ConformerEncoderLayer( d_model=attention_dim, ext_pw_out_channel=ext_pw_out_channel, depthwise_seperable_out_channel=depthwise_seperable_out_channel, depthwise_multiplier=depthwise_multiplier, n_head=attention_heads, d_ffn=linear_units, ext_pw_kernel_size=ext_pw_kernel_size, kernel_size=kernel_size, dropout_rate=dropout_rate, causal=causal, batch_norm=batch_norm, activation=activation, chunk_se=chunk_se, chunk_size=chunk_size, conv_activation=conv_activation, conv_glu_type=conv_glu_type, bias_in_glu=bias_in_glu, linear_glu_in_convm=linear_glu_in_convm, attention_glu_type=attention_glu_type, activation_checkpointing=activation_checkpointing, export=export, use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, attn_group_sizes=attention_group_size, ) for _ in range(num_blocks) ] ) self.extra_layer_output_idx = extra_layer_output_idx self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs # Make a zeros scalar we can use in get_initial_state to determine # the device and the needed dtype: self.register_buffer("dev_type", torch.zeros(()), persistent=False) def init_relative_attention_bias( self, input_tensor: torch.Tensor ) -> torch.Tensor | None: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) def calculate_hs_mask( self, xs_pad: torch.Tensor, device: torch.device, mask: torch.Tensor | None ) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.chunk_size, self.left_chunk ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = torch.arange(0, max_audio_length, device=device).expand( padding_length.size(0), -1 ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask @torch.jit.ignore def forward( self, xs_pad: torch.Tensor, masks: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Conformer Forward function Args: xs_pad: torch.Tensor input tensor masks: torch.Tensor post-embedding input lengths """ xs_pad = self.encoder_embedding(xs_pad) input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( xs_pad, masks ) unfolded = False ori_bz, seq_len, D = input_tensor.shape max_seq_len = 500 # maximum position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks # of max_seq_len unfolded = True # the unfold op will drop residual frames, pad it to the multiple # of max_seq_len if seq_len % max_seq_len > 0: chunk_pad_size = max_seq_len - (seq_len % max_seq_len) else: chunk_pad_size = 0 if chunk_pad_size > 0: input_tensor_pad = F.pad( input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0 ) input_tensor = input_tensor_pad.to(input_tensor.device) input_tensor = unfold_tensor(input_tensor, max_seq_len) if masks is not None: # revise hs_mask here because the previous calculated hs_mask # did not consider extra pad subsampled_pad_mask = masks.squeeze( 1 ) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( subsampled_pad_mask, (0, chunk_pad_size), "constant", False ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = ( extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() ) masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze( -1 ).bool() # unfold op does not support bool tensor else: masks_unfold = None hs_mask = self.calculate_hs_mask( input_tensor, input_tensor.device, masks_unfold ) # calculate hs_mask based on the unfolded pad mask # layer_emb = None relative_attention_bias = self.init_relative_attention_bias(input_tensor) _simplified_path = ( self.extra_layer_output_idx == -1 and relative_attention_bias is None ) if _simplified_path: input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) else: for i, layer in enumerate(self.encoders): input_tensor, _, _, _ = layer( input_tensor, pos_k, pos_v, hs_mask, relative_attention_bias=relative_attention_bias, ) # if i == self.extra_layer_output_idx: # layer_emb = input_tensor if unfolded: embed_dim = input_tensor.shape[-1] input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) # if we ever padded before unfolding, we need to remove the padding if chunk_pad_size > 0: input_tensor = input_tensor[:, :-chunk_pad_size, :] return input_tensor, masks # , layer_emb class WindowQformer(nn.Module): """Window-level Qformer""" def __init__( self, window_size: int = 8, num_queries: int = 1, num_blocks: int = 2, attention_dim: int = 512, attention_heads: int = 8, linear_units: int = 2048, dropout_rate: float = 0.0, normalize_before: bool = True, ): super().__init__() self.decoders = nn.ModuleList( [ nn.TransformerDecoderLayer( d_model=attention_dim, nhead=attention_heads, dim_feedforward=linear_units, dropout=dropout_rate, activation="relu", batch_first=True, norm_first=normalize_before, # TODO need to verify ) for _ in range(num_blocks) ] ) self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) self.after_norm = ( nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None ) self.window_size = window_size def forward( self, audio_embed: torch.Tensor, mask: torch.Tensor | None, embed_len: int | None = None, ) -> tuple[torch.Tensor, int | None]: """forward decoder""" # audio_embed: N x T x D => N x D x T audio_embed = audio_embed.transpose(1, 2) # audio_embed: N x D x 1 x T => N x DK x T' padding = audio_embed.shape[-1] % self.window_size if padding > 0: audio_embed = F.pad( audio_embed, (0, self.window_size - padding), "constant", 0 ) embed_chunk = F.unfold( audio_embed[..., None, :], kernel_size=(1, self.window_size), stride=(1, self.window_size), ) bsz, _, slen = embed_chunk.shape # N x D x K x T' embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) # N x T' x K x D embed_chunk = embed_chunk.transpose(1, 3).contiguous() # NT' x K x D embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) # NT' x 1 x D q = self.queries.expand(bsz * slen, -1, -1) for layer in self.decoders: q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask) if self.after_norm is not None: q = self.after_norm(q) if embed_len is not None: embed_len = embed_len // self.window_size # N x T' x D out = q.view(bsz, slen, -1) return out, embed_len class AudioEmbedding(nn.Module): """Image embedding.""" def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size # self.wte = nn.Embedding(config.vocab_size, hidden_size) audio_dim_out = ( None # Set this variable according to the actual audio processor ) self.layer_idx = -2 if ( isinstance(config.audio_processor, dict) and config.audio_processor.get("name", None) == "cascades" ): encoder_config = config.audio_processor.get("config", None) assert encoder_config is not None self.encoder = ConformerEncoder(**encoder_config) audio_dim_out = encoder_config["attention_dim"] n_mels = encoder_config["input_size"] else: raise NotImplementedError("") assert audio_dim_out is not None, "Remember to set values for audio_dim_out" self.audio_dim_out = audio_dim_out self.audio_dim_in = n_mels self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) self.downsample_rate = kwargs.get("downsample_rate", 1) if kwargs.get("use_qformer", False): qformer_config = kwargs.get("qformer_config", {}) qformer_config["attention_dim"] = audio_dim_out self.qformer = WindowQformer(**qformer_config) else: self.qformer = None if kwargs.get("use_conv_downsample", False): assert self.qformer is None, ( "don't support use qformer and conv downsample together" ) nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) default_nemo_conv_settings = { "subsampling": "dw_striding", "subsampling_factor": self.downsample_rate, "feat_in": audio_dim_out, "feat_out": audio_dim_out, "conv_channels": 256, "subsampling_conv_chunking_factor": 1, "activation": nn.ReLU(), "is_causal": False, } # Override any of the defaults with the incoming, user settings if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: assert i not in nemo_conv_settings, ( "{i} should be specified outside of the NeMo dictionary" ) self.conv_ds = NemoConvSubsampling( **default_nemo_conv_settings, ) else: self.conv_ds = None projection_cls = kwargs.get("projection_cls", "linear") if projection_cls == "linear": self.audio_projection = nn.Linear(audio_dim_out, hidden_size) elif projection_cls == "mlp": # follow llava-v1.5's implementation # (do not use image_projection and image_proj_norm) dim_projection = hidden_size depth = 2 self.linear_downsample_rate = ( 1 if (self.qformer or self.conv_ds) else self.downsample_rate ) layers = [ nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection = nn.Sequential(*layers) # NOTE vision-speech tasks use a separate projection layer layers = [ nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection_for_vision = nn.Sequential(*layers) else: raise NotImplementedError( f"projection_cls = {projection_cls}, not implemented" ) # TODO: audio sequence compression - Qformer self.vocab_size = config.vocab_size self.input_embeds = None self.audio_embed_sizes = None def set_audio_embeds(self, input_embeds: torch.Tensor) -> None: self.input_embeds = input_embeds def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None: self.audio_embed_sizes = audio_embed_sizes def get_audio_features( self, input_embeds: torch.Tensor, audio_attention_mask: torch.Tensor | None = None, audio_projection_mode: str = "speech", ) -> torch.Tensor: """ arguments: input_embeds: audio features (B, T, D) B: num audios in a sequence """ if self.freeze_audio_processor: with torch.no_grad(): audio_features, masks = self.encoder(input_embeds, audio_attention_mask) else: audio_features, masks = self.encoder(input_embeds, audio_attention_mask) if self.qformer is not None: audio_features, _ = self.qformer(audio_features, mask=None) if self.conv_ds is not None: if masks is not None: masks = masks.squeeze(1) audio_features, masks = self.conv_ds(audio_features, mask=masks) if self.linear_downsample_rate != 1: bs, seq_len, feat_dim = audio_features.size() padding = seq_len % self.linear_downsample_rate if padding > 0: audio_features = F.pad( audio_features, (0, 0, 0, self.linear_downsample_rate - padding), "constant", 0, ) seq_len = audio_features.size(1) audio_features = audio_features.view( bs, seq_len // self.linear_downsample_rate, feat_dim * self.linear_downsample_rate, ) if audio_projection_mode == "speech": audio_set_tensor = self.audio_projection(audio_features) elif audio_projection_mode == "vision": audio_set_tensor = self.audio_projection_for_vision(audio_features) else: raise ValueError( f"audio_projection_mode = {audio_projection_mode} not implemented" ) return audio_set_tensor def forward( self, audio_features: torch.Tensor, audio_attention_mask: torch.Tensor | None = None, audio_projection_mode: str = "speech", ) -> torch.Tensor: """ arguments: audio_features: audio features (T, D) returns: audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ audio_embeds = self.get_audio_features( audio_features.unsqueeze(0), audio_attention_mask=audio_attention_mask, audio_projection_mode=audio_projection_mode, ) return audio_embeds.squeeze(0)