[Docs] Fix warnings in mkdocs build (continued) (#24092)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Hyogeun Oh (오효근) 2025-09-10 22:23:28 +09:00 committed by GitHub
parent c0bd6a684a
commit ccee371e86
10 changed files with 337 additions and 342 deletions


@ -755,7 +755,7 @@ class FusedMoE(CustomOp):
intermediate_size: Intermediate size of the experts intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure. quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer. enable_eplb: Whether to enable expert parallelism load balancer.
""" """


@ -420,9 +420,8 @@ def shuffle_weights(
Args: Args:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the layout: A pair of integers specifying the block sizes used to divide
block sizes used to divide the tensors during shuffling. the tensors during shuffling. Default is (16, 16).
Default is (16, 16).
Returns: Returns:
A Tuple of shuffled tensors. A Tuple of shuffled tensors.
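For context on the `layout` argument, a hypothetical sketch of viewing a weight matrix as `(16, 16)` blocks; the real `shuffle_weights` applies a kernel-specific permutation that is not reproduced here:

```python
import torch


def view_as_blocks(t: torch.Tensor,
                   layout: tuple[int, int] = (16, 16)) -> torch.Tensor:
    """View a 2D tensor of shape (M, N) as (M//bm, N//bn, bm, bn) blocks."""
    bm, bn = layout
    m, n = t.shape
    assert m % bm == 0 and n % bn == 0, "dims must be divisible by the layout"
    return t.reshape(m // bm, bm, n // bn, bn).permute(0, 2, 1, 3).contiguous()
```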


@ -10,7 +10,7 @@ like uniform random routing.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Any, Optional
import torch import torch
@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns. distributions for testing different routing patterns.
""" """
def __init__(self, distribution: str = "uniform", **distribution_params): def __init__(self,
distribution: str = "uniform",
**distribution_params: Any):
""" """
Initialize distribution-based routing. Initialize distribution-based routing.
@ -244,7 +246,7 @@ class RoutingSimulator:
cls._routing_strategies[name] = strategy cls._routing_strategies[name] = strategy
@classmethod @classmethod
def get_available_strategies(cls): def get_available_strategies(cls) -> list[str]:
""" """
Get list of available routing strategy names. Get list of available routing strategy names.
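A simplified illustration of the registry pattern behind `get_available_strategies()`; everything beyond the names visible in the hunk is an assumption, not the module's actual code:

```python
from abc import ABC, abstractmethod

import torch


class RoutingStrategy(ABC):

    @abstractmethod
    def route(self, hidden_states: torch.Tensor, num_experts: int,
              top_k: int) -> torch.Tensor:
        """Return (num_tokens, top_k) expert ids."""


class UniformRandomRouting(RoutingStrategy):

    def route(self, hidden_states: torch.Tensor, num_experts: int,
              top_k: int) -> torch.Tensor:
        return torch.randint(0, num_experts, (hidden_states.shape[0], top_k))


class RoutingSimulator:
    _routing_strategies: dict[str, RoutingStrategy] = {
        "uniform_random": UniformRandomRouting(),
    }

    @classmethod
    def register_strategy(cls, name: str, strategy: RoutingStrategy) -> None:
        cls._routing_strategies[name] = strategy

    @classmethod
    def get_available_strategies(cls) -> list[str]:
        return list(cls._routing_strategies)


print(RoutingSimulator.get_available_strategies())  # ['uniform_random']
```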


@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase):
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ) -> None:
"""Creates quantized weights for use in linear operations. """Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing quantized The function initializes and returns a dictionary containing quantized
@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase):
Args: Args:
input_size_per_partition: The size of the input partition. input_size_per_partition: The size of the input partition.
output_size_per_partition: The size of the output partition. output_partition_sizes: List of output partition sizes.
input_size: The total size of the input (unused). input_size: The total size of the input (unused).
output_size: The total size of the output (unused). output_size: The total size of the output (unused).
params_dtype: params_dtype:
@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase):
scales ('scales'), and zeros ('zeros'). scales ('scales'), and zeros ('zeros').
Raises: Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the ValueError: If `params_dtype` is not `torch.float16` or if the input
input size per partition is not divisible by the group size in size per partition is not divisible by the group size
`quant_config`. in `quant_config`.
""" """
del input_size, output_size # Unused arguments. del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"] weight_loader = extra_weight_attrs["weight_loader"]
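The reworded `Raises` section corresponds to validation along these lines (hypothetical standalone helper; the actual method continues on to allocate the quantized weight, scale and zero parameters):

```python
import torch


def validate_bitblas_weight_args(params_dtype: torch.dtype,
                                 input_size_per_partition: int,
                                 group_size: int) -> None:
    if params_dtype != torch.float16:
        raise ValueError("Parameter data type must be torch.float16, "
                         f"but got {params_dtype}")
    if group_size != -1 and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"input_size_per_partition ({input_size_per_partition}) must be "
            f"divisible by group_size ({group_size})")
```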


@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
scales ('scales'), and zeros ('zeros'). scales ('scales'), and zeros ('zeros').
Raises: Raises:
ValueError: If `params_dtype` is not `torch.float16` or ValueError: If `params_dtype` is not `torch.float16` or if the input
if the input size per partition is not divisible by the size per partition is not divisible by the group size
group size in `quant_config`. in `quant_config`.
""" """
if params_dtype != torch.float16: if params_dtype != torch.float16:
raise ValueError("Parameter data type must be torch.float16, " raise ValueError("Parameter data type must be torch.float16, "


@ -49,8 +49,8 @@ def choose_mp_linear_kernel(
config (MPLinearLayerConfig): Description of the linear layer to be config (MPLinearLayerConfig): Description of the linear layer to be
implemented. implemented.
compute_capability (Optional[int], optional): The compute capability of compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute the target device, if None uses `current_platform` to get
capability. Defaults to None. the compute capability. Defaults to None.
Raises: Raises:
ValueError: If no kernel can implement the given config. ValueError: If no kernel can implement the given config.
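A hypothetical sketch of the selection behaviour the docstring describes; the candidate list, the `can_implement` check and the capability probe are assumptions, not the module's actual API:

```python
from typing import Optional


def pick_first_usable_kernel(candidates: list, config,
                             compute_capability: Optional[int] = None):
    if compute_capability is None:
        compute_capability = 90  # assume it is queried from the current platform
    failures = []
    for kernel in candidates:
        ok, reason = kernel.can_implement(config, compute_capability)
        if ok:
            return kernel
        failures.append(f"  {kernel.__name__}: {reason}")
    raise ValueError("No kernel can implement the given config:\n" +
                     "\n".join(failures))
```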


@ -7,7 +7,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import abc import abc
import math import math
from typing import Literal, Optional from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
d_model=512, d_model: int = 512,
ext_pw_out_channel=0, ext_pw_out_channel: int = 0,
depthwise_seperable_out_channel=256, depthwise_seperable_out_channel: int = 256,
depthwise_multiplier=1, depthwise_multiplier: int = 1,
n_head=4, n_head: int = 4,
d_ffn=2048, d_ffn: int = 2048,
ext_pw_kernel_size=1, ext_pw_kernel_size: int = 1,
kernel_size=3, kernel_size: int = 3,
dropout_rate=0.1, dropout_rate: float = 0.1,
causal=False, causal: bool = False,
batch_norm=False, batch_norm: bool = False,
activation="relu", activation: str = "relu",
chunk_se=0, chunk_se: int = 0,
chunk_size=18, chunk_size: int = 18,
conv_activation="relu", conv_activation: str = "relu",
conv_glu_type="sigmoid", conv_glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
linear_glu_in_convm=False, linear_glu_in_convm: bool = False,
attention_inner_dim=-1, attention_inner_dim: int = -1,
attention_glu_type="swish", attention_glu_type: str = "swish",
activation_checkpointing="", activation_checkpointing: str = "",
export=False, export: bool = False,
use_pt_scaled_dot_product_attention=False, use_pt_scaled_dot_product_attention: bool = False,
attn_group_sizes: int = 1, attn_group_sizes: int = 1,
): ) -> None:
super().__init__() super().__init__()
self.feed_forward_in = FeedForward( self.feed_forward_in = FeedForward(
@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module):
def forward( def forward(
self, self,
x, x: torch.Tensor,
pos_k, pos_k: torch.Tensor,
pos_v, pos_v: torch.Tensor,
mask, mask: torch.Tensor,
relative_attention_bias: Optional[Tensor] = None, relative_attention_bias: Optional[Tensor] = None,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""ConformerEncoder forward. """ConformerEncoder forward.
Args: Args:
x: torch.Tensor x: input feature of shape (batch, max_time_in, size)
input feature of shape (batch, max_time_in, size) pos_k: positional key embedding.
pos_k: torch.Tensor pos_v: positional value embedding.
positional key embedding. mask: mask for x (batch, max_time_in)
mask: torch.Tensor relative_attention_bias: bias added to attention logits w.r.t.
mask for x (batch, max_time_in) relative positions (1, n_head, time1, time2)
relative_attention_bias: Optional[torch.Tensor]
bias added to attention logits w.r.t. relative positions
(1, n_head, time1, time2)
""" """
x = x + 0.5 * self.feed_forward_in(x) x = x + 0.5 * self.feed_forward_in(x)
norm_x = self.layer_norm_att(x) norm_x = self.layer_norm_att(x)
@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
def __init__( def __init__(
self, self,
input_size, input_size: int,
chunk_size, chunk_size: Union[int, list[int]],
left_chunk, left_chunk: Union[int, list[int]],
attention_dim=256, attention_dim: int = 256,
attention_heads=4, attention_heads: int = 4,
input_layer="nemo_conv", input_layer: str = "nemo_conv",
cnn_out=-1, cnn_out: int = -1,
cnn_layer_norm=False, cnn_layer_norm: bool = False,
time_reduction=4, time_reduction: int = 4,
dropout_rate=0.0, dropout_rate: float = 0.0,
padding_idx=-1, padding_idx: int = -1,
relative_attention_bias_args=None, relative_attention_bias_args: Optional[dict[str, Any]] = None,
positional_dropout_rate=0.0, positional_dropout_rate: float = 0.0,
nemo_conv_settings=None, nemo_conv_settings: Optional[dict[str, Any]] = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", conv2d_extra_padding: Literal["feat", "feat_time", "none",
True] = "none", True] = "none",
attention_group_size=1, attention_group_size: int = 1,
encoder_embedding_config=None, encoder_embedding_config: Optional[dict[str, Any]] = None,
): ) -> None:
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size
self.input_layer = input_layer self.input_layer = input_layer
@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
self.encoder_embedding = MeanVarianceNormLayer( self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"]) self.encoder_embedding_config["input_size"])
def compute_lens_change(self, feature_lens): def compute_lens_change(
self,
feature_lens: Union[int,
torch.Tensor]) -> Union[int, torch.Tensor]:
"""feature_lens: int """feature_lens: int
return updated feature lens. return updated feature lens.
@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return ceil_func(feature_lens / self.time_reduction) return ceil_func(feature_lens / self.time_reduction)
@abc.abstractmethod @abc.abstractmethod
def forward(self): def forward(self) -> Any:
"""Abstract forward method implementation.""" """Abstract forward method implementation."""
def _chunk_size_selection(self, chunk_size=None, left_chunk=None): def _chunk_size_selection(
self,
chunk_size: Optional[Union[int, list[int]]] = None,
left_chunk: Optional[Union[int,
list[int]]] = None) -> tuple[int, int]:
"""If chunk size is a list, we will randomly select a chunk size.""" """If chunk size is a list, we will randomly select a chunk size."""
if chunk_size is None: if chunk_size is None:
@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return chunk_size_train_eff, left_chunk_train_eff return chunk_size_train_eff, left_chunk_train_eff
def _get_embed_class(self, embed): def _get_embed_class(self, embed: nn.Module) -> nn.Module:
# pylint: disable=protected-access # pylint: disable=protected-access
is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
embed_class = embed.module embed_class = embed.module
return embed_class return embed_class
def _forward_embeddings_core(self, input_tensor, masks): def _forward_embeddings_core(
self, input_tensor: torch.Tensor,
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
embed_class = self._get_embed_class(self.embed) embed_class = self._get_embed_class(self.embed)
assert isinstance(embed_class, NemoConvSubsampling) assert isinstance(embed_class, NemoConvSubsampling)
input_tensor, masks = self.embed(input_tensor, masks) input_tensor, masks = self.embed(input_tensor, masks)
return input_tensor, masks return input_tensor, masks
def _position_embedding(self, input_tensor): def _position_embedding(
self, input_tensor: torch.Tensor
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
pos_k = None pos_k = None
pos_v = None pos_v = None
if self.relative_attention_bias_layer is None: if self.relative_attention_bias_layer is None:
@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
input_tensor) # default to add abs sinusoid embedding input_tensor) # default to add abs sinusoid embedding
return pos_k, pos_v return pos_k, pos_v
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): def _streaming_mask(self, seq_len: int, batch_size: int,
chunk_size: Union[int, list[int]],
left_chunk: Union[int, list[int]]) -> torch.Tensor:
chunk_size_train_eff, left_chunk_train_eff = \ chunk_size_train_eff, left_chunk_train_eff = \
self._chunk_size_selection(chunk_size, left_chunk) self._chunk_size_selection(chunk_size, left_chunk)
@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
[batch_size, -1, -1])) [batch_size, -1, -1]))
return enc_streaming_mask return enc_streaming_mask
def forward_embeddings(self, def forward_embeddings(
xs_pad, self,
masks, xs_pad: torch.Tensor,
chunk_size_nc=None, masks: torch.Tensor,
left_chunk_nc=None): chunk_size_nc: Optional[Union[int, list[int]]] = None,
left_chunk_nc: Optional[Union[int, list[int]]] = None
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor, torch.Tensor],
tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]]:
"""Forwarding the inputs through the top embedding layers """Forwarding the inputs through the top embedding layers
Args: Args:
@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks
return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
def get_offset(self): def get_offset(self) -> int:
"""Returns offset used when retaining inputs for decoding. """Returns offset used when retaining inputs for decoding.
This is essentially, how many additional frames have to be added to This is essentially, how many additional frames have to be added to
@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase):
Some examples for the 2 cases: Some examples for the 2 cases:
left_chunk = 6 left_chunk = 6
left_chunk = [12, 9, 6, 3] left_chunk = [12, 9, 6, 3]
left_chunk: int
number of chunks used for masking in streaming mode.
num_lang: int num_lang: int
This parameter is used to store the number of languages in the This parameter is used to store the number of languages in the
lang_dict, only used for multiseed/multilingual models. lang_dict, only used for multiseed/multilingual models.
@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase):
def __init__( # pylint: disable-all def __init__( # pylint: disable-all
self, self,
input_size, input_size: int,
chunk_size, chunk_size: Union[int, list[int]],
left_chunk, left_chunk: Union[int, list[int]],
num_lang=None, num_lang: Optional[int] = None,
attention_dim=256, attention_dim: int = 256,
attention_heads=4, attention_heads: int = 4,
linear_units=2048, linear_units: int = 2048,
num_blocks=6, num_blocks: int = 6,
dropout_rate=0.1, dropout_rate: float = 0.1,
input_layer="nemo_conv", input_layer: str = "nemo_conv",
causal=True, causal: bool = True,
batch_norm=False, batch_norm: bool = False,
cnn_out=-1, cnn_out: int = -1,
cnn_layer_norm=False, cnn_layer_norm: bool = False,
ext_pw_out_channel=0, ext_pw_out_channel: int = 0,
ext_pw_kernel_size=1, ext_pw_kernel_size: int = 1,
depthwise_seperable_out_channel=256, depthwise_seperable_out_channel: int = 256,
depthwise_multiplier=1, depthwise_multiplier: int = 1,
chunk_se=0, chunk_se: int = 0,
kernel_size=3, kernel_size: int = 3,
activation="relu", activation: str = "relu",
conv_activation="relu", conv_activation: str = "relu",
conv_glu_type="sigmoid", conv_glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
linear_glu_in_convm=False, linear_glu_in_convm: bool = False,
attention_glu_type="swish", attention_glu_type: str = "swish",
export=False, export: bool = False,
extra_layer_output_idx=-1, extra_layer_output_idx: int = -1,
extra_multi_layer_output_idxs=[], # noqa extra_multi_layer_output_idxs: list[int] = [], # noqa
activation_checkpointing="", activation_checkpointing: str = "",
relative_attention_bias_args=None, relative_attention_bias_args: Optional[dict[str, Any]] = None,
time_reduction=4, time_reduction: int = 4,
use_pt_scaled_dot_product_attention=False, use_pt_scaled_dot_product_attention: bool = False,
nemo_conv_settings=None, nemo_conv_settings: Optional[dict[str, Any]] = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", conv2d_extra_padding: Literal["feat", "feat_time", "none",
True] = "none", True] = "none",
replication_pad_for_subsample_embedding=False, replication_pad_for_subsample_embedding: bool = False,
attention_group_size=1, attention_group_size: int = 1,
encoder_embedding_config=None, encoder_embedding_config: Optional[dict[str, Any]] = None,
): ) -> None:
super().__init__( super().__init__(
input_size, input_size,
chunk_size, chunk_size,
@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase):
# the device and the needed dtype: # the device and the needed dtype:
self.register_buffer("dev_type", torch.zeros(()), persistent=False) self.register_buffer("dev_type", torch.zeros(()), persistent=False)
def init_relative_attention_bias(self, input_tensor): def init_relative_attention_bias(
self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
if self.relative_attention_bias_layer: if self.relative_attention_bias_layer:
return self.relative_attention_bias_layer(input_tensor) return self.relative_attention_bias_layer(input_tensor)
def calculate_hs_mask(self, xs_pad, device, mask): def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
mask: Optional[torch.Tensor]) -> torch.Tensor:
max_audio_length = xs_pad.shape[1] max_audio_length = xs_pad.shape[1]
batch_size = xs_pad.shape[0] batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase):
return pad_mask return pad_mask
@torch.jit.ignore @torch.jit.ignore
def forward(self, xs_pad, masks): def forward(self, xs_pad: torch.Tensor,
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Conformer Forward function """Conformer Forward function
Args: Args:
@ -997,7 +1014,12 @@ class WindowQformer(nn.Module):
if normalize_before else None) if normalize_before else None)
self.window_size = window_size self.window_size = window_size
def forward(self, audio_embed, mask, embed_len=None): def forward(
self,
audio_embed: torch.Tensor,
mask: Optional[torch.Tensor],
embed_len: Optional[int] = None
) -> tuple[torch.Tensor, Optional[int]]:
"""forward decoder""" """forward decoder"""
# audio_embed: N x T x D => N x D x T # audio_embed: N x T x D => N x D x T
@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module):
class AudioEmbedding(nn.Module): class AudioEmbedding(nn.Module):
"""Image embedding.""" """Image embedding."""
def __init__(self, config: PretrainedConfig, **kwargs) -> None: def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
# n_embed or hidden_size for text LM # n_embed or hidden_size for text LM
@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module):
self.input_embeds = None self.input_embeds = None
self.audio_embed_sizes = None self.audio_embed_sizes = None
def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
self.input_embeds = input_embeds self.input_embeds = input_embeds
def set_audio_embed_sizes(self, def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
audio_embed_sizes: torch.LongTensor) -> None:
self.audio_embed_sizes = audio_embed_sizes self.audio_embed_sizes = audio_embed_sizes
def get_audio_features( def get_audio_features(
self, self,
input_embeds: torch.FloatTensor, input_embeds: torch.Tensor,
audio_attention_mask: torch.Tensor = None, audio_attention_mask: Optional[torch.Tensor] = None,
audio_projection_mode: str = "speech", audio_projection_mode: str = "speech",
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
arguments: arguments:
input_embeds: audio features (B, T, D) B: num audios in a sequence input_embeds: audio features (B, T, D) B: num audios in a sequence
@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module):
def forward( def forward(
self, self,
audio_features: torch.FloatTensor, audio_features: torch.Tensor,
audio_attention_mask: torch.Tensor = None, audio_attention_mask: Optional[torch.Tensor] = None,
audio_projection_mode: str = "speech", audio_projection_mode: str = "speech",
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
arguments: arguments:
audio_features: audio features (T, D) audio_features: audio features (T, D)


@ -16,13 +16,13 @@ from torch import Tensor, nn
class BlockBase(nn.Module): class BlockBase(nn.Module):
"""Block abstract module""" """Block abstract module"""
def __init__(self, input_size, output_size): def __init__(self, input_size: int, output_size: int) -> None:
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size
self.output_size = output_size self.output_size = output_size
def get_activation(name="relu"): def get_activation(name: str = "relu") -> torch.nn.Module:
"""Select an activation function by name """Select an activation function by name
Args: Args:
@ -43,15 +43,18 @@ def get_activation(name="relu"):
return nn.Identity() return nn.Identity()
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): def adaptive_enc_mask(x_len: int,
chunk_start_idx: list[int],
left_window: int = 0,
right_window: int = 0) -> torch.Tensor:
""" """
The function is very important for Transformer Transducer Streaming mode The function is very important for Transformer Transducer Streaming mode
Args: Args:
xs_len (int): sequence length x_len: sequence length
chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. chunk_start_idx: first idx of each chunk, such as [0,18,36,48].
It also supports adaptive chunk size [0,10,15,45] It also supports adaptive chunk size [0,10,15,45]
left_window (int): how many left chunks can be seen left_window: how many left chunks can be seen
right_window (int): how many right chunks can be seen. It is used for right_window: how many right chunks can be seen. It is used for
chunk overlap model. chunk overlap model.
Returns: Returns:
mask (torch.Tensor): a mask tensor for streaming model mask (torch.Tensor): a mask tensor for streaming model
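A minimal sketch of the chunk-based mask described above, assuming `chunk_start_idx` is sorted and starts at 0 (illustration only, not the function body changed here):

```python
import torch


def chunk_streaming_mask(x_len: int,
                         chunk_start_idx: list[int],
                         left_window: int = 0,
                         right_window: int = 0) -> torch.Tensor:
    starts = torch.tensor(chunk_start_idx)                      # e.g. [0, 18, 36, 48]
    pos = torch.arange(x_len)
    # chunk id of each frame = index of the last chunk start <= position
    chunk_id = torch.searchsorted(starts, pos, right=True) - 1  # (x_len,)
    diff = chunk_id.unsqueeze(1) - chunk_id.unsqueeze(0)        # (x_len, x_len)
    # frame i may attend to frame j if j's chunk lies within the allowed windows
    return (diff <= left_window) & (-diff <= right_window)
```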
@ -172,13 +175,13 @@ class GLUPointWiseConv(nn.Module):
def __init__( def __init__(
self, self,
input_dim, input_dim: int,
output_dim, output_dim: int,
kernel_size, kernel_size: int,
glu_type="sigmoid", glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
causal=False, causal: bool = False,
): ) -> None:
super().__init__() super().__init__()
self.glu_type = glu_type self.glu_type = glu_type
@ -216,11 +219,10 @@ class GLUPointWiseConv(nn.Module):
self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
""" """
Args: Args:
x: torch.Tensor x: input tensor
input tensor
""" """
# to be consistent with GLULinear, we assume the input always has the # to be consistent with GLULinear, we assume the input always has the
# #channel (#dim) in the last dimension of the tensor, so need to # #channel (#dim) in the last dimension of the tensor, so need to
@ -272,12 +274,12 @@ class DepthWiseSeperableConv1d(nn.Module):
def __init__( def __init__(
self, self,
input_dim, input_dim: int,
depthwise_seperable_out_channel, depthwise_seperable_out_channel: int,
kernel_size, kernel_size: int,
depthwise_multiplier, depthwise_multiplier: int,
padding=0, padding: int = 0,
): ) -> None:
super().__init__() super().__init__()
self.dw_conv = nn.Conv1d( self.dw_conv = nn.Conv1d(
@ -301,12 +303,11 @@ class DepthWiseSeperableConv1d(nn.Module):
self.pw_conv = nn.Identity() self.pw_conv = nn.Identity()
self.depthwise_seperable_out_channel = depthwise_seperable_out_channel self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
""" """
Args: Args:
x: torch.Tensor x: input tensor
input tensor
""" """
x = self.dw_conv(x) x = self.dw_conv(x)
if self.depthwise_seperable_out_channel != 0: if self.depthwise_seperable_out_channel != 0:
@ -375,23 +376,23 @@ class ConvModule(nn.Module):
def __init__( def __init__(
self, self,
input_dim, input_dim: int,
ext_pw_out_channel, ext_pw_out_channel: int,
depthwise_seperable_out_channel, depthwise_seperable_out_channel: int,
ext_pw_kernel_size, ext_pw_kernel_size: int,
kernel_size, kernel_size: int,
depthwise_multiplier, depthwise_multiplier: int,
dropout_rate, dropout_rate: float,
causal=False, causal: bool = False,
batch_norm=False, batch_norm: bool = False,
chunk_se=0, chunk_se: int = 0,
chunk_size=18, chunk_size: int = 18,
activation="relu", activation: str = "relu",
glu_type="sigmoid", glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
linear_glu_in_convm=False, linear_glu_in_convm: bool = False,
export=False, export: bool = False,
): ) -> None:
super().__init__() super().__init__()
self.layer_norm = nn.LayerNorm(input_dim) self.layer_norm = nn.LayerNorm(input_dim)
self.input_dim = input_dim self.input_dim = input_dim
@ -437,7 +438,7 @@ class ConvModule(nn.Module):
self.ln2 = nn.Linear(input_dim * depthwise_multiplier, self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
input_dim) input_dim)
def _add_ext_pw_layer(self): def _add_ext_pw_layer(self) -> None:
""" """
This function is an extension of __init__ function This function is an extension of __init__ function
and dedicated to the convolution module creation and dedicated to the convolution module creation
@ -497,12 +498,11 @@ class ConvModule(nn.Module):
self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
"""ConvModule Forward. """ConvModule Forward.
Args: Args:
x: torch.Tensor x: input tensor.
input tensor.
""" """
x = self.layer_norm(x) x = self.layer_norm(x)
@ -567,21 +567,20 @@ class GLULinear(nn.Module):
def __init__( def __init__(
self, self,
input_dim, input_dim: int,
output_dim, output_dim: int,
glu_type="sigmoid", glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
): ) -> None:
super().__init__() super().__init__()
self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
self.glu_act = GLU(-1, glu_type) self.glu_act = GLU(-1, glu_type)
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
"""GLULinear forward """GLULinear forward
Args: Args:
x: torch.Tensor x: input tensor.
inpute tensor.
""" """
x = self.linear(x) x = self.linear(x)
return self.glu_act(x) return self.glu_act(x)
@ -609,12 +608,12 @@ class FeedForward(nn.Module):
def __init__( def __init__(
self, self,
d_model, d_model: int,
d_inner, d_inner: int,
dropout_rate, dropout_rate: float,
activation="sigmoid", activation: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
): ) -> None:
super().__init__() super().__init__()
self.d_model = d_model self.d_model = d_model
self.d_inner = d_inner self.d_inner = d_inner
@ -628,12 +627,11 @@ class FeedForward(nn.Module):
nn.Dropout(dropout_rate), nn.Dropout(dropout_rate),
) )
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
"""FeedForward forward function. """FeedForward forward function.
Args: Args:
x: torch.Tensor x: input tensor.
input tensor.
""" """
out = self.net(self.layer_norm(x)) out = self.net(self.layer_norm(x))
@ -642,14 +640,14 @@ class FeedForward(nn.Module):
#### positional encoding starts here #### positional encoding starts here
def _pre_hook( def _pre_hook(
state_dict, state_dict: dict,
prefix, prefix: str,
local_metadata, local_metadata: dict,
strict, strict: bool,
missing_keys, missing_keys: list[str],
unexpected_keys, unexpected_keys: list[str],
error_msgs, error_msgs: list[str],
): ) -> None:
"""Perform pre-hook in load_state_dict for backward compatibility. """Perform pre-hook in load_state_dict for backward compatibility.
Note: Note:
@ -708,10 +706,10 @@ class T5RelativeAttentionLogitBias(nn.Module):
""" """
def __init__(self, def __init__(self,
num_heads, num_heads: int,
num_buckets=-1, num_buckets: int = -1,
max_distance=1000, max_distance: int = 1000,
symmetric=False): symmetric: bool = False) -> None:
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.num_buckets = num_buckets self.num_buckets = num_buckets
@ -727,7 +725,7 @@ class T5RelativeAttentionLogitBias(nn.Module):
self.num_buckets *= 2 self.num_buckets *= 2
self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
# instantiate bias compatible with shape of x # instantiate bias compatible with shape of x
maxpos = x.size(1) maxpos = x.size(1)
context_position = torch.arange(maxpos, context_position = torch.arange(maxpos,
@ -760,7 +758,7 @@ class T5RelativeAttentionLogitBias(nn.Module):
return t5_rel_att_bias return t5_rel_att_bias
def _bucket_relative_position(self, relative_position): def _bucket_relative_position(self, relative_position: Tensor) -> Tensor:
# this is a placeholder (isn't tested, likely buggy) using HuggingFace # this is a placeholder (isn't tested, likely buggy) using HuggingFace
# implem as a reference this also needs to be extended to support # implem as a reference this also needs to be extended to support
# asymmetric +/- ve positions # asymmetric +/- ve positions
@ -810,7 +808,10 @@ class AbsolutePositionalEncoding(nn.Module):
""" """
def __init__(self, d_model, dropout_rate, max_len=5000): def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super().__init__() super().__init__()
self.d_model = d_model self.d_model = d_model
@ -820,11 +821,11 @@ class AbsolutePositionalEncoding(nn.Module):
self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook) self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, x): def extend_pe(self, x: torch.Tensor) -> None:
"""Reset the positional encodings. """Reset the positional encodings.
Args: Args:
x: torch.Tensor x: input tensor
""" """
if self.pe is not None and self.pe.size(1) >= x.size(1): if self.pe is not None and self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device: if self.pe.dtype != x.dtype or self.pe.device != x.device:
@ -840,15 +841,14 @@ class AbsolutePositionalEncoding(nn.Module):
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Add positional encoding. """Add positional encoding.
Args: Args:
x: torch.Tensor x: Input tensor. shape is (batch, time, ...)
Input tensor. shape is (batch, time, ...)
Returns: Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) Encoded tensor. Its shape is (batch, time, ...)
""" """
self.extend_pe(x) self.extend_pe(x)
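For reference, a minimal sketch of the sinusoidal table that `extend_pe()` caches, assuming the standard Transformer formulation and an even `d_model`:

```python
import math

import torch


def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) *
        (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)  # (1, max_len, d_model), added to (B, T, D) inputs
```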
@ -868,7 +868,7 @@ class MeanVarianceNormLayer(nn.Module):
layer input size. layer input size.
""" """
def __init__(self, input_size): def __init__(self, input_size: int) -> None:
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size
self.global_mean = nn.Parameter(torch.zeros(input_size)) self.global_mean = nn.Parameter(torch.zeros(input_size))
@ -878,8 +878,7 @@ class MeanVarianceNormLayer(nn.Module):
"""MeanVarianceNormLayer Forward """MeanVarianceNormLayer Forward
Args: Args:
input_: torch.Tensor input_: input tensor.
input tensor.
""" """
return (input_ - self.global_mean) * self.global_invstd return (input_ - self.global_mean) * self.global_invstd
@ -949,7 +948,10 @@ class CausalConv1D(nn.Conv1d):
dtype=dtype, dtype=dtype,
) )
def update_cache(self, x, cache=None): def update_cache(
self,
x: Tensor,
cache: Optional[Tensor] = None) -> tuple[Tensor, Optional[Tensor]]:
if cache is None: if cache is None:
new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
next_cache = cache next_cache = cache
@ -963,7 +965,11 @@ class CausalConv1D(nn.Conv1d):
next_cache = next_cache[:, :, -cache.size(-1):] next_cache = next_cache[:, :, -cache.size(-1):]
return new_x, next_cache return new_x, next_cache
def forward(self, x, cache=None): def forward(
self,
x: Tensor,
cache: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]:
x, cache = self.update_cache(x, cache=cache) x, cache = self.update_cache(x, cache=cache)
x = super().forward(x) x = super().forward(x)
if cache is None: if cache is None:
@ -1017,8 +1023,8 @@ class CausalConv2D(nn.Conv2d):
def forward( def forward(
self, self,
x, x: Tensor,
): ) -> Tensor:
x = F.pad( x = F.pad(
x, x,
pad=(self._left_padding, self._right_padding, 0, 0), pad=(self._left_padding, self._right_padding, 0, 0),
@ -1063,15 +1069,15 @@ class NemoConvSubsampling(torch.nn.Module):
def __init__( def __init__(
self, self,
feat_in, feat_in: int,
feat_out, feat_out: int,
subsampling_factor=4, subsampling_factor: int = 4,
subsampling="dw_striding", subsampling: str = "dw_striding",
conv_channels=256, conv_channels: int = 256,
subsampling_conv_chunking_factor=1, subsampling_conv_chunking_factor: int = 1,
activation=nn.ReLU(), # noqa: B008 activation: torch.nn.Module = nn.ReLU(), # noqa: B008
is_causal=False, is_causal: bool = False,
): ) -> None:
super().__init__() super().__init__()
self._subsampling = subsampling self._subsampling = subsampling
self._conv_channels = conv_channels self._conv_channels = conv_channels
@ -1328,28 +1334,25 @@ class NemoConvSubsampling(torch.nn.Module):
self.conv = torch.nn.Sequential(*layers) self.conv = torch.nn.Sequential(*layers)
def get_sampling_frames(self): def get_sampling_frames(self) -> list[int]:
return [1, self.subsampling_factor] return [1, self.subsampling_factor]
def get_streaming_cache_size(self): def get_streaming_cache_size(self) -> list[int]:
return [0, self.subsampling_factor + 1] return [0, self.subsampling_factor + 1]
def forward(self, x, mask): def forward(self, x: Tensor,
mask: Optional[Tensor]) -> tuple[Tensor, Optional[Tensor]]:
""" """
Forward method for NeMo subsampling. Forward method for NeMo subsampling.
Args: Args:
x[Batch, Time, Filters]: torch.Tensor x: input tensor
input tensor mask: input mask
x_mask: torch.Tensor
input mask
Returns: Returns:
x: torch.Tensor x: Resulting tensor from subsampling (B, T //
Resulting tensor from subsampling (B, T //
time_reduction_factor, feat_out) time_reduction_factor, feat_out)
pad_mask: torch.Tensor pad_mask: tensor of padded hidden state sequences (B, 1, T //
tensor of padded hidden state sequences (B, 1, T //
time_reduction_factor) time_reduction_factor)
""" """
x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
@ -1403,7 +1406,7 @@ class NemoConvSubsampling(torch.nn.Module):
padding_length.size(0), -1) < padding_length.unsqueeze(1) padding_length.size(0), -1) < padding_length.unsqueeze(1)
return x, pad_mask.unsqueeze(1) return x, pad_mask.unsqueeze(1)
def reset_parameters(self): def reset_parameters(self) -> None:
# initialize weights # initialize weights
if self._subsampling == "dw_striding": if self._subsampling == "dw_striding":
with torch.no_grad(): with torch.no_grad():
@ -1433,7 +1436,7 @@ class NemoConvSubsampling(torch.nn.Module):
torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
def conv_split_by_batch(self, x): def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]:
"""Tries to split input by batch, run conv and concat results""" """Tries to split input by batch, run conv and concat results"""
b, _, _, _ = x.size() b, _, _, _ = x.size()
if b == 1: # can't split if batch size is 1 if b == 1: # can't split if batch size is 1
@ -1460,7 +1463,7 @@ class NemoConvSubsampling(torch.nn.Module):
True, True,
) )
def conv_split_by_channel(self, x): def conv_split_by_channel(self, x: Tensor) -> Tensor:
"""For dw convs, tries to split input by time, run conv and concat """For dw convs, tries to split input by time, run conv and concat
results""" results"""
x = self.conv[0](x) # full conv2D x = self.conv[0](x) # full conv2D
@ -1500,7 +1503,8 @@ class NemoConvSubsampling(torch.nn.Module):
x = self.conv[i * 3 + 4](x) # activation x = self.conv[i * 3 + 4](x) # activation
return x return x
def channel_chunked_conv(self, conv, chunk_size, x): def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int,
x: Tensor) -> Tensor:
"""Performs channel chunked convolution""" """Performs channel chunked convolution"""
ind = 0 ind = 0
@ -1541,7 +1545,7 @@ class NemoConvSubsampling(torch.nn.Module):
return torch.cat(out_chunks, 1) return torch.cat(out_chunks, 1)
def change_subsampling_conv_chunking_factor( def change_subsampling_conv_chunking_factor(
self, subsampling_conv_chunking_factor: int): self, subsampling_conv_chunking_factor: int) -> None:
if (subsampling_conv_chunking_factor != -1 if (subsampling_conv_chunking_factor != -1
and subsampling_conv_chunking_factor != 1 and subsampling_conv_chunking_factor != 1
and subsampling_conv_chunking_factor % 2 != 0): and subsampling_conv_chunking_factor % 2 != 0):
@ -1552,12 +1556,12 @@ class NemoConvSubsampling(torch.nn.Module):
self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
def calc_length(lengths, def calc_length(lengths: Tensor,
all_paddings, all_paddings: int,
kernel_size, kernel_size: int,
stride, stride: int,
ceil_mode, ceil_mode: bool,
repeat_num=1): repeat_num: int = 1) -> Tensor:
"""Calculates the output length of a Tensor passed through a convolution or """Calculates the output length of a Tensor passed through a convolution or
max pooling layer""" max pooling layer"""
add_pad: float = all_paddings - kernel_size add_pad: float = all_paddings - kernel_size
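The helper applies the standard convolution length formula once per layer; a sketch under the assumption that `all_paddings` is the total (left plus right) padding:

```python
import torch


def conv_output_lengths(lengths: torch.Tensor, all_paddings: int,
                        kernel_size: int, stride: int, ceil_mode: bool,
                        repeat_num: int = 1) -> torch.Tensor:
    add_pad = float(all_paddings - kernel_size)
    out = lengths.to(torch.float32)
    for _ in range(repeat_num):
        # out = (L + padding - kernel) / stride + 1, floored or ceiled
        out = (out + add_pad) / float(stride) + 1.0
        out = torch.ceil(out) if ceil_mode else torch.floor(out)
    return out.to(torch.int64)
```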
@ -1573,11 +1577,11 @@ def calc_length(lengths,
class AttModule(nn.Module): class AttModule(nn.Module):
"""Attention abstraction module""" """Attention abstraction module"""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.export_mode = False self.export_mode = False
def set_export(self, mode=True): def set_export(self, mode: bool = True) -> None:
"""set the export mode""" """set the export mode"""
self.export_mode = mode self.export_mode = mode
@ -1591,14 +1595,10 @@ class AttModule(nn.Module):
"""AttModule forward """AttModule forward
Args: Args:
x: torch.Tensor x: input tensor.
input tensor. memory: memory tensor.
memory: torch.Tensor, optional pos_emb: positional encoder embedding.
memory tensor. att_mask: attention mask tensor.
pos_emb: torch.Tensor, optional
positional encoder embedding.
att_mask: torch.Tensor, optional
attention mask tensor.
""" """
return x, memory, pos_emb, att_mask return x, memory, pos_emb, att_mask
@ -1606,15 +1606,15 @@ class AttModule(nn.Module):
class AttBlock(BlockBase, AttModule): class AttBlock(BlockBase, AttModule):
"""Attention Block module to support both Attention and Block module.""" """Attention Block module to support both Attention and Block module."""
def memory_dims(self, max_len=False): def memory_dims(self, max_len: bool = False) -> tuple[int, int]:
"""memory dimensions""" """memory dimensions"""
return (1, self.input_size) return (1, self.input_size)
def masked_softmax( def masked_softmax(
scores, scores: Tensor,
mask: Optional[Tensor], mask: Optional[Tensor],
): ) -> Tensor:
if mask is not None: if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -torch.inf) scores = scores.masked_fill(mask, -torch.inf)
@ -1636,10 +1636,6 @@ class MultiHeadedAttention(nn.Module):
input size features. input size features.
dropout_rate: float dropout_rate: float
dropout rate. dropout rate.
use_LN: bool
apply layer norm or not
dropout_at_output: bool
whether to apply dropout at output
attention_inner_dim: int, optional attention_inner_dim: int, optional
the attention dimension used in the class, the attention dimension used in the class,
it can be different from the input dimension n_feat. it can be different from the input dimension n_feat.
@ -1666,16 +1662,16 @@ class MultiHeadedAttention(nn.Module):
def __init__( def __init__(
self, self,
n_head, n_head: int,
n_feat, n_feat: int,
dropout_rate, dropout_rate: float,
attention_inner_dim=-1, attention_inner_dim: int = -1,
glu_type="swish", glu_type: str = "swish",
bias_in_glu=True, bias_in_glu: bool = True,
use_pt_scaled_dot_product_attention=False, use_pt_scaled_dot_product_attention: bool = False,
n_value=-1, n_value: int = -1,
group_size: int = 1, group_size: int = 1,
): ) -> None:
super().__init__() super().__init__()
if n_value == -1: if n_value == -1:
n_value = n_feat n_value = n_feat
@ -1718,28 +1714,22 @@ class MultiHeadedAttention(nn.Module):
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,
value: Tensor, value: Tensor,
pos_k: Tensor, pos_k: Optional[Tensor],
pos_v: Tensor, pos_v: Optional[Tensor],
mask: Optional[Tensor], mask: Optional[Tensor],
relative_attention_bias: Optional[Tensor] = None, relative_attention_bias: Optional[Tensor] = None,
): ) -> Tensor:
"""Compute 'Scaled Dot Product Attention'. """Compute 'Scaled Dot Product Attention'.
Args: Args:
query: torch.Tensor query: query tensor (batch, time1, size)
query tensor (batch, time1, size) key: key tensor (batch, time2, size)
key: torch.Tensor value: value tensor (batch, time1, size)
key tensor (batch, time2, size) pos_k: key tensor used for relative positional embedding.
value: torch.Tensor pos_v: value tensor used for relative positional embedding.
value tensor (batch, time1, size) mask: mask tensor (batch, time1, time2)
pos_k: torch.Tensor relative_attention_bias: bias added to attention logits w.r.t.
key tensor used for relative positional embedding. relative positions
pos_v: torch.Tensor
value tensor used for relative positional embedding.
mask: torch.Tensor
mask tensor (batch, time1, time2)
relative_attention_bias: torch.Tensor
bias added to attention logits w.r.t. relative positions
(1, n_head, time1, time2) (1, n_head, time1, time2)
""" """
n_batch = query.size(0) n_batch = query.size(0)
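The reworded docstring corresponds to the usual scaled-dot-product computation; a compact sketch that ignores the GLU projections and grouped heads this class also supports:

```python
import math
from typing import Optional

import torch


def scaled_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       mask: Optional[torch.Tensor] = None,
                       rel_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # q, k, v: (batch, n_head, time, d_k); mask: (batch, time1, time2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if rel_bias is not None:
        scores = scores + rel_bias                      # (1, n_head, time1, time2)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1).eq(0), float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```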
@ -1832,20 +1822,20 @@ class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential""" """Multi-input multi-output torch.nn.Sequential"""
@torch.jit.ignore @torch.jit.ignore
def forward(self, *args): def forward(self, *args) -> tuple:
"""Forward method implementation.""" """Forward method implementation."""
for m in self: for m in self:
args = m(*args) args = m(*args)
return args return args
def get_offset(input_layer: str, time_reduction: int): def get_offset(input_layer: str, time_reduction: int) -> int:
"""Get an offset. We will use the offset for determining #frames of a """Get an offset. We will use the offset for determining #frames of a
subsampled feature. subsampled feature.
Args: Args:
input_layer (str): Type of an input layer input_layer: Type of an input layer
time_reduction (int): time reduction factor for downsampling a feature time_reduction: time reduction factor for downsampling a feature
Returns: Returns:
int: offset int: offset
""" """
@ -1858,13 +1848,14 @@ def get_offset(input_layer: str, time_reduction: int):
return 0 return 0
def unfold_tensor(xs_pad, max_seq_len): def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor:
""" """
For a given tensor with shape of (N, T, D), if sequence length T is For a given tensor with shape of (N, T, D), if sequence length T is
longer than max_seq_len, this function unfold it to a longer than max_seq_len, this function unfold it to a
(NT', max_seq_len, D) where T' is T // max_seq_len. (NT', max_seq_len, D) where T' is T // max_seq_len.
Args: Args:
xs_pad: N, T, D xs_pad: input tensor with shape (N, T, D)
max_seq_len: maximum sequence length
""" """
_, _, D = xs_pad.shape _, _, D = xs_pad.shape
xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
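The reshape described in the docstring amounts to the following, assuming `T` has already been padded to a multiple of `max_seq_len`:

```python
import torch


def unfold_by_length(xs_pad: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    n, t, d = xs_pad.shape
    assert t % max_seq_len == 0, "pad T to a multiple of max_seq_len first"
    return xs_pad.reshape(n * (t // max_seq_len), max_seq_len, d)


print(unfold_by_length(torch.randn(2, 8, 4), 4).shape)  # torch.Size([4, 4, 4])
```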


@ -1193,21 +1193,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
positions: Flattened (concatenated) position ids corresponding to a positions: Flattened (concatenated) position ids corresponding to a
batch. batch. **NOTE**: If mrope is enabled (default setting for
**NOTE**: If mrope is enabled (default setting for Qwen2.5-VL Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,). otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
second_per_grid_ts: Tensor `(num_videos)` of video time interval (
in seconds) for each grid along the temporal dimension in the
3D position IDs. `None` if no videos are passed.
""" """
if intermediate_tensors is not None: if intermediate_tensors is not None:
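Illustration of the shape convention the retained note describes (toy values only; the real position ids come from the mrope preprocessing):

```python
import torch

seq_len = 6
positions_plain = torch.arange(seq_len)               # (seq_len,) without mrope
positions_mrope = torch.arange(seq_len).repeat(3, 1)  # (3, seq_len): t / h / w rows
print(positions_plain.shape, positions_mrope.shape)
```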


@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers.
""" """
from collections.abc import Iterable from collections.abc import Iterable
from itertools import cycle from itertools import cycle
from typing import Optional, Union from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states: Input tensor [batch_size, seq_len, hidden_size] hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm) (one for conv, one for ssm)
sequence_idx: Index tensor for identifying sequences in batch
Required for proper chunked processing in prefill
transformer_hidden_states: Optional output from transformer path transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture) Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba) positions: Optional position IDs (unused in Mamba)
@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module):
Args: Args:
shared_transformer: Transformer decoder layer for attention pathway shared_transformer: Transformer decoder layer for attention pathway
linear: Linear projection for transformer output before Mamba
mamba: Mamba decoder layer for state space pathway
""" """
super().__init__() super().__init__()
self.block_idx = block_idx self.block_idx = block_idx
@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module):
positions: Position IDs for positional embeddings positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm) (one for conv, one for ssm)
sequence_idx: Indices for identifying sequences in batch,
required for proper chunked processing in prefill
Returns: Returns:
Output tensor combining transformer and Mamba representations Output tensor combining transformer and Mamba representations
@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
prefix: Optional prefix for parameter names prefix: Optional prefix for parameter names
Raises: Raises:
AssertionError: If prefix caching is enabled (not supported by AssertionError: If prefix caching is enabled
Mamba) (not supported by Mamba)
""" """
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor: **kwargs: Any) -> torch.Tensor:
"""Forward pass through the model. """Forward pass through the model.
Args: Args:
@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str, def copy_inputs_before_cuda_graphs(
torch.Tensor], self, input_buffers: dict[str, torch.Tensor],
**kwargs) -> dict[str, torch.Tensor]: **kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture. """Copy inputs before CUDA graph capture.
Args: Args: