Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 07:25:01 +08:00)
[Docs] Fix warnings in mkdocs build (continued) (#24092)
Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent c0bd6a684a
commit ccee371e86
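The diff below applies one pattern throughout: keyword arguments gain explicit type hints, functions gain return annotations, and docstrings drop the two-line `name: torch.Tensor` / indented-description layout in favor of single-line Google-style `Args:` entries that mkdocs can render without warnings. The module below is a hypothetical sketch of that target style — it is not taken from the commit, and the class name, parameters, and shapes are illustrative only.

```python
from typing import Optional

import torch
from torch import nn


class GLULinearExample(nn.Module):
    """Linear projection followed by a GLU-style gate.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        bias: Whether the projection uses a bias term.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 bias: bool = True) -> None:
        super().__init__()
        # Project to 2 * output_dim so GLU can split into value and gate halves.
        self.linear = nn.Linear(input_dim, output_dim * 2, bias=bias)
        self.glu = nn.GLU(dim=-1)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the gated projection.

        Args:
            x: Input tensor of shape (batch, time, input_dim).
            mask: Optional boolean padding mask of shape (batch, time).

        Returns:
            Tensor of shape (batch, time, output_dim).
        """
        out = self.glu(self.linear(x))
        if mask is not None:
            # Zero out padded positions; broadcast the mask over features.
            out = out * mask.unsqueeze(-1).to(out.dtype)
        return out
```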
@@ -755,7 +755,7 @@ class FusedMoE(CustomOp):
 intermediate_size: Intermediate size of the experts
 params_dtype: Data type for the parameters.
 reduce_results: Whether to all all_reduce on the output of the layer
-renomalize: Whether to renormalize the logits in the fused_moe kernel
+renormalize: Whether to renormalize the logits in the fused_moe kernel
 quant_config: Quantization configure.
 enable_eplb: Whether to enable expert parallelism load balancer.
 """
@@ -420,9 +420,8 @@ def shuffle_weights(

 Args:
 *tensors: Variable number of torch.Tensor objects.
-layout: A pair of integers specifying the
-block sizes used to divide the tensors during shuffling.
-Default is (16, 16).
+layout: A pair of integers specifying the block sizes used to divide
+the tensors during shuffling. Default is (16, 16).

 Returns:
 A Tuple of shuffled tensors.
@@ -10,7 +10,7 @@ like uniform random routing.
 """

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
 distributions for testing different routing patterns.
 """

-def __init__(self, distribution: str = "uniform", **distribution_params):
+def __init__(self,
+distribution: str = "uniform",
+**distribution_params: Any):
 """
 Initialize distribution-based routing.

@@ -244,7 +246,7 @@ class RoutingSimulator:
 cls._routing_strategies[name] = strategy

 @classmethod
-def get_available_strategies(cls):
+def get_available_strategies(cls) -> list[str]:
 """
 Get list of available routing strategy names.

@@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase):
 output_size: int,
 params_dtype: torch.dtype,
 **extra_weight_attrs,
-):
+) -> None:
 """Creates quantized weights for use in linear operations.

 The function initializes and returns a dictionary containing quantized
@@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase):

 Args:
 input_size_per_partition: The size of the input partition.
-output_size_per_partition: The size of the output partition.
+output_partition_sizes: List of output partition sizes.
 input_size: The total size of the input (unused).
 output_size: The total size of the output (unused).
 params_dtype:
@@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase):
 scales ('scales'), and zeros ('zeros').

 Raises:
-ValueError: If `params_dtype` is not `torch.float16` or if the
-input size per partition is not divisible by the group size in
-`quant_config`.
+ValueError: If `params_dtype` is not `torch.float16` or if the input
+size per partition is not divisible by the group size
+in `quant_config`.
 """
 del input_size, output_size # Unused arguments.
 weight_loader = extra_weight_attrs["weight_loader"]
@@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
 scales ('scales'), and zeros ('zeros').

 Raises:
-ValueError: If `params_dtype` is not `torch.float16` or
-if the input size per partition is not divisible by the
-group size in `quant_config`.
+ValueError: If `params_dtype` is not `torch.float16` or if the input
+size per partition is not divisible by the group size
+in `quant_config`.
 """
 if params_dtype != torch.float16:
 raise ValueError("Parameter data type must be torch.float16, "
@@ -47,10 +47,10 @@ def choose_mp_linear_kernel(

 Args:
 config (MPLinearLayerConfig): Description of the linear layer to be
 implemented.
 compute_capability (Optional[int], optional): The compute capability of
-the target device, if None uses `current_platform` to get the compute
-capability. Defaults to None.
+the target device, if None uses `current_platform` to get
+the compute capability. Defaults to None.

 Raises:
 ValueError: If no kernel can implement the given config.
@@ -7,7 +7,7 @@
 #!/usr/bin/env python3
 import abc
 import math
-from typing import Literal, Optional
+from typing import Any, Literal, Optional, Union

 import numpy as np
 import torch
@@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module):

 def __init__(
 self,
-d_model=512,
-ext_pw_out_channel=0,
-depthwise_seperable_out_channel=256,
-depthwise_multiplier=1,
-n_head=4,
-d_ffn=2048,
-ext_pw_kernel_size=1,
-kernel_size=3,
-dropout_rate=0.1,
-causal=False,
-batch_norm=False,
-activation="relu",
-chunk_se=0,
-chunk_size=18,
-conv_activation="relu",
-conv_glu_type="sigmoid",
-bias_in_glu=True,
-linear_glu_in_convm=False,
-attention_inner_dim=-1,
-attention_glu_type="swish",
-activation_checkpointing="",
-export=False,
-use_pt_scaled_dot_product_attention=False,
+d_model: int = 512,
+ext_pw_out_channel: int = 0,
+depthwise_seperable_out_channel: int = 256,
+depthwise_multiplier: int = 1,
+n_head: int = 4,
+d_ffn: int = 2048,
+ext_pw_kernel_size: int = 1,
+kernel_size: int = 3,
+dropout_rate: float = 0.1,
+causal: bool = False,
+batch_norm: bool = False,
+activation: str = "relu",
+chunk_se: int = 0,
+chunk_size: int = 18,
+conv_activation: str = "relu",
+conv_glu_type: str = "sigmoid",
+bias_in_glu: bool = True,
+linear_glu_in_convm: bool = False,
+attention_inner_dim: int = -1,
+attention_glu_type: str = "swish",
+activation_checkpointing: str = "",
+export: bool = False,
+use_pt_scaled_dot_product_attention: bool = False,
 attn_group_sizes: int = 1,
-):
+) -> None:
 super().__init__()

 self.feed_forward_in = FeedForward(
@@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module):

 def forward(
 self,
-x,
-pos_k,
-pos_v,
-mask,
+x: torch.Tensor,
+pos_k: torch.Tensor,
+pos_v: torch.Tensor,
+mask: torch.Tensor,
 relative_attention_bias: Optional[Tensor] = None,
-):
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 """ConformerEncoder forward.

 Args:
-x: torch.Tensor
-input feature of shape (batch, max_time_in, size)
-pos_k: torch.Tensor
-positional key embedding.
-mask: torch.Tensor
-mask for x (batch, max_time_in)
-relative_attention_bias: Optional[torch.Tensor]
-bias added to attention logits w.r.t. relative positions
-(1, n_head, time1, time2)
+x: input feature of shape (batch, max_time_in, size)
+pos_k: positional key embedding.
+pos_v: positional value embedding.
+mask: mask for x (batch, max_time_in)
+relative_attention_bias: bias added to attention logits w.r.t.
+relative positions (1, n_head, time1, time2)
 """
 x = x + 0.5 * self.feed_forward_in(x)
 norm_x = self.layer_norm_att(x)
@@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module):

 def __init__(
 self,
-input_size,
-chunk_size,
-left_chunk,
-attention_dim=256,
-attention_heads=4,
-input_layer="nemo_conv",
-cnn_out=-1,
-cnn_layer_norm=False,
-time_reduction=4,
-dropout_rate=0.0,
-padding_idx=-1,
-relative_attention_bias_args=None,
-positional_dropout_rate=0.0,
-nemo_conv_settings=None,
+input_size: int,
+chunk_size: Union[int, list[int]],
+left_chunk: Union[int, list[int]],
+attention_dim: int = 256,
+attention_heads: int = 4,
+input_layer: str = "nemo_conv",
+cnn_out: int = -1,
+cnn_layer_norm: bool = False,
+time_reduction: int = 4,
+dropout_rate: float = 0.0,
+padding_idx: int = -1,
+relative_attention_bias_args: Optional[dict[str, Any]] = None,
+positional_dropout_rate: float = 0.0,
+nemo_conv_settings: Optional[dict[str, Any]] = None,
 conv2d_extra_padding: Literal["feat", "feat_time", "none",
 True] = "none",
-attention_group_size=1,
-encoder_embedding_config=None,
-):
+attention_group_size: int = 1,
+encoder_embedding_config: Optional[dict[str, Any]] = None,
+) -> None:
 super().__init__()
 self.input_size = input_size
 self.input_layer = input_layer
@@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 self.encoder_embedding = MeanVarianceNormLayer(
 self.encoder_embedding_config["input_size"])

-def compute_lens_change(self, feature_lens):
+def compute_lens_change(
+self,
+feature_lens: Union[int,
+torch.Tensor]) -> Union[int, torch.Tensor]:
 """feature_lens: int
 return updated feature lens.

@@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 return ceil_func(feature_lens / self.time_reduction)

 @abc.abstractmethod
-def forward(self):
+def forward(self) -> Any:
 """Abstract forward method implementation."""

-def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
+def _chunk_size_selection(
+self,
+chunk_size: Optional[Union[int, list[int]]] = None,
+left_chunk: Optional[Union[int,
+list[int]]] = None) -> tuple[int, int]:
 """If chunk size is a list, we will randomly select a chunk size."""

 if chunk_size is None:
@@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):

 return chunk_size_train_eff, left_chunk_train_eff

-def _get_embed_class(self, embed):
+def _get_embed_class(self, embed: nn.Module) -> nn.Module:
 # pylint: disable=protected-access
 is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
 is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
@@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 embed_class = embed.module
 return embed_class

-def _forward_embeddings_core(self, input_tensor, masks):
+def _forward_embeddings_core(
+self, input_tensor: torch.Tensor,
+masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 embed_class = self._get_embed_class(self.embed)
 assert isinstance(embed_class, NemoConvSubsampling)
 input_tensor, masks = self.embed(input_tensor, masks)
 return input_tensor, masks

-def _position_embedding(self, input_tensor):
+def _position_embedding(
+self, input_tensor: torch.Tensor
+) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
 pos_k = None
 pos_v = None
 if self.relative_attention_bias_layer is None:
@@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 input_tensor) # default to add abs sinusoid embedding
 return pos_k, pos_v

-def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
+def _streaming_mask(self, seq_len: int, batch_size: int,
+chunk_size: Union[int, list[int]],
+left_chunk: Union[int, list[int]]) -> torch.Tensor:
 chunk_size_train_eff, left_chunk_train_eff = \
 self._chunk_size_selection(chunk_size, left_chunk)

@@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 [batch_size, -1, -1]))
 return enc_streaming_mask

-def forward_embeddings(self,
-xs_pad,
-masks,
-chunk_size_nc=None,
-left_chunk_nc=None):
+def forward_embeddings(
+self,
+xs_pad: torch.Tensor,
+masks: torch.Tensor,
+chunk_size_nc: Optional[Union[int, list[int]]] = None,
+left_chunk_nc: Optional[Union[int, list[int]]] = None
+) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
+Optional[torch.Tensor], torch.Tensor, torch.Tensor],
+tuple[torch.Tensor, Optional[torch.Tensor],
+Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+torch.Tensor]]:
 """Forwarding the inputs through the top embedding layers

 Args:
@@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
 return input_tensor, pos_k, pos_v, hs_mask, masks
 return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

-def get_offset(self):
+def get_offset(self) -> int:
 """Returns offset used when retaining inputs for decoding.

 This is essentially, how many additional frames have to be added to
@@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase):
 Some examples for the 2 cases:
 left_chunk = 6
 left_chunk = [12, 9, 6, 3]
-left_chunk: int
-number of chunks used for masking in streaming mode.
 num_lang: int
 This parameter is used to store the number of languages in the
 lang_dict, only used for multiseed/multilingual models.
@@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase):

 def __init__( # pylint: disable-all
 self,
-input_size,
-chunk_size,
-left_chunk,
-num_lang=None,
-attention_dim=256,
-attention_heads=4,
-linear_units=2048,
-num_blocks=6,
-dropout_rate=0.1,
-input_layer="nemo_conv",
-causal=True,
-batch_norm=False,
-cnn_out=-1,
-cnn_layer_norm=False,
-ext_pw_out_channel=0,
-ext_pw_kernel_size=1,
-depthwise_seperable_out_channel=256,
-depthwise_multiplier=1,
-chunk_se=0,
-kernel_size=3,
-activation="relu",
-conv_activation="relu",
-conv_glu_type="sigmoid",
-bias_in_glu=True,
-linear_glu_in_convm=False,
-attention_glu_type="swish",
-export=False,
-extra_layer_output_idx=-1,
-extra_multi_layer_output_idxs=[], # noqa
-activation_checkpointing="",
-relative_attention_bias_args=None,
-time_reduction=4,
-use_pt_scaled_dot_product_attention=False,
-nemo_conv_settings=None,
+input_size: int,
+chunk_size: Union[int, list[int]],
+left_chunk: Union[int, list[int]],
+num_lang: Optional[int] = None,
+attention_dim: int = 256,
+attention_heads: int = 4,
+linear_units: int = 2048,
+num_blocks: int = 6,
+dropout_rate: float = 0.1,
+input_layer: str = "nemo_conv",
+causal: bool = True,
+batch_norm: bool = False,
+cnn_out: int = -1,
+cnn_layer_norm: bool = False,
+ext_pw_out_channel: int = 0,
+ext_pw_kernel_size: int = 1,
+depthwise_seperable_out_channel: int = 256,
+depthwise_multiplier: int = 1,
+chunk_se: int = 0,
+kernel_size: int = 3,
+activation: str = "relu",
+conv_activation: str = "relu",
+conv_glu_type: str = "sigmoid",
+bias_in_glu: bool = True,
+linear_glu_in_convm: bool = False,
+attention_glu_type: str = "swish",
+export: bool = False,
+extra_layer_output_idx: int = -1,
+extra_multi_layer_output_idxs: list[int] = [], # noqa
+activation_checkpointing: str = "",
+relative_attention_bias_args: Optional[dict[str, Any]] = None,
+time_reduction: int = 4,
+use_pt_scaled_dot_product_attention: bool = False,
+nemo_conv_settings: Optional[dict[str, Any]] = None,
 conv2d_extra_padding: Literal["feat", "feat_time", "none",
 True] = "none",
-replication_pad_for_subsample_embedding=False,
-attention_group_size=1,
-encoder_embedding_config=None,
-):
+replication_pad_for_subsample_embedding: bool = False,
+attention_group_size: int = 1,
+encoder_embedding_config: Optional[dict[str, Any]] = None,
+) -> None:
 super().__init__(
 input_size,
 chunk_size,
@@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase):
 # the device and the needed dtype:
 self.register_buffer("dev_type", torch.zeros(()), persistent=False)

-def init_relative_attention_bias(self, input_tensor):
+def init_relative_attention_bias(
+self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
 if self.relative_attention_bias_layer:
 return self.relative_attention_bias_layer(input_tensor)

-def calculate_hs_mask(self, xs_pad, device, mask):
+def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
+mask: Optional[torch.Tensor]) -> torch.Tensor:
 max_audio_length = xs_pad.shape[1]
 batch_size = xs_pad.shape[0]
 enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
@@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase):
 return pad_mask

 @torch.jit.ignore
-def forward(self, xs_pad, masks):
+def forward(self, xs_pad: torch.Tensor,
+masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 """Conformer Forward function

 Args:
@@ -997,7 +1014,12 @@ class WindowQformer(nn.Module):
 if normalize_before else None)
 self.window_size = window_size

-def forward(self, audio_embed, mask, embed_len=None):
+def forward(
+self,
+audio_embed: torch.Tensor,
+mask: Optional[torch.Tensor],
+embed_len: Optional[int] = None
+) -> tuple[torch.Tensor, Optional[int]]:
 """forward decoder"""
 # audio_embed: N x T x D => N x D x T

@@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module):
 class AudioEmbedding(nn.Module):
 """Image embedding."""

-def __init__(self, config: PretrainedConfig, **kwargs) -> None:
+def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
 super().__init__()
 self.config = config
 # n_embed or hidden_size for text LM
@@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module):
 self.input_embeds = None
 self.audio_embed_sizes = None

-def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
+def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
 self.input_embeds = input_embeds

-def set_audio_embed_sizes(self,
-audio_embed_sizes: torch.LongTensor) -> None:
+def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
 self.audio_embed_sizes = audio_embed_sizes

 def get_audio_features(
 self,
-input_embeds: torch.FloatTensor,
-audio_attention_mask: torch.Tensor = None,
+input_embeds: torch.Tensor,
+audio_attention_mask: Optional[torch.Tensor] = None,
 audio_projection_mode: str = "speech",
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 """
 arguments:
 input_embeds: audio features (B, T, D) B: num audios in a sequence
@@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module):

 def forward(
 self,
-audio_features: torch.FloatTensor,
-audio_attention_mask: torch.Tensor = None,
+audio_features: torch.Tensor,
+audio_attention_mask: Optional[torch.Tensor] = None,
 audio_projection_mode: str = "speech",
-) -> torch.FloatTensor:
+) -> torch.Tensor:
 """
 arguments:
 audio_features: audio features (T, D)
@@ -16,13 +16,13 @@ from torch import Tensor, nn
 class BlockBase(nn.Module):
 """Block abstract module"""

-def __init__(self, input_size, output_size):
+def __init__(self, input_size: int, output_size: int) -> None:
 super().__init__()
 self.input_size = input_size
 self.output_size = output_size


-def get_activation(name="relu"):
+def get_activation(name: str = "relu") -> torch.nn.Module:
 """Select an activation function by name

 Args:
@@ -43,15 +43,18 @@ def get_activation(name="relu"):
 return nn.Identity()


-def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+def adaptive_enc_mask(x_len: int,
+chunk_start_idx: list[int],
+left_window: int = 0,
+right_window: int = 0) -> torch.Tensor:
 """
 The function is very important for Transformer Transducer Streaming mode
 Args:
-xs_len (int): sequence length
-chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
+x_len: sequence length
+chunk_start_idx: first idx of each chunk, such as [0,18,36,48].
 It also supports adaptive chunk size [0,10,15,45]
-left_window (int): how many left chunks can be seen
-right_window (int): how many right chunks can be seen. It is used for
+left_window: how many left chunks can be seen
+right_window: how many right chunks can be seen. It is used for
 chunk overlap model.
 Returns:
 mask (torch.Tensor): a mask tensor for streaming model
@@ -172,13 +175,13 @@ class GLUPointWiseConv(nn.Module):

 def __init__(
 self,
-input_dim,
-output_dim,
-kernel_size,
-glu_type="sigmoid",
-bias_in_glu=True,
-causal=False,
-):
+input_dim: int,
+output_dim: int,
+kernel_size: int,
+glu_type: str = "sigmoid",
+bias_in_glu: bool = True,
+causal: bool = False,
+) -> None:
 super().__init__()

 self.glu_type = glu_type
@@ -216,11 +219,10 @@ class GLUPointWiseConv(nn.Module):
 self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
 self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 """
 Args:
-x: torch.Tensor
-input tensor
+x: input tensor
 """
 # to be consistent with GLULinear, we assume the input always has the
 # #channel (#dim) in the last dimension of the tensor, so need to
@@ -272,12 +274,12 @@ class DepthWiseSeperableConv1d(nn.Module):

 def __init__(
 self,
-input_dim,
-depthwise_seperable_out_channel,
-kernel_size,
-depthwise_multiplier,
-padding=0,
-):
+input_dim: int,
+depthwise_seperable_out_channel: int,
+kernel_size: int,
+depthwise_multiplier: int,
+padding: int = 0,
+) -> None:
 super().__init__()

 self.dw_conv = nn.Conv1d(
@@ -301,12 +303,11 @@ class DepthWiseSeperableConv1d(nn.Module):
 self.pw_conv = nn.Identity()
 self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 """

 Args:
-x: torch.Tensor
-input tensor
+x: input tensor
 """
 x = self.dw_conv(x)
 if self.depthwise_seperable_out_channel != 0:
@@ -375,23 +376,23 @@ class ConvModule(nn.Module):

 def __init__(
 self,
-input_dim,
-ext_pw_out_channel,
-depthwise_seperable_out_channel,
-ext_pw_kernel_size,
-kernel_size,
-depthwise_multiplier,
-dropout_rate,
-causal=False,
-batch_norm=False,
-chunk_se=0,
-chunk_size=18,
-activation="relu",
-glu_type="sigmoid",
-bias_in_glu=True,
-linear_glu_in_convm=False,
-export=False,
-):
+input_dim: int,
+ext_pw_out_channel: int,
+depthwise_seperable_out_channel: int,
+ext_pw_kernel_size: int,
+kernel_size: int,
+depthwise_multiplier: int,
+dropout_rate: float,
+causal: bool = False,
+batch_norm: bool = False,
+chunk_se: int = 0,
+chunk_size: int = 18,
+activation: str = "relu",
+glu_type: str = "sigmoid",
+bias_in_glu: bool = True,
+linear_glu_in_convm: bool = False,
+export: bool = False,
+) -> None:
 super().__init__()
 self.layer_norm = nn.LayerNorm(input_dim)
 self.input_dim = input_dim
@@ -437,7 +438,7 @@ class ConvModule(nn.Module):
 self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
 input_dim)

-def _add_ext_pw_layer(self):
+def _add_ext_pw_layer(self) -> None:
 """
 This function is an extension of __init__ function
 and dedicated to the convolution module creation
@@ -497,12 +498,11 @@ class ConvModule(nn.Module):
 self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
 self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 """ConvModule Forward.

 Args:
-x: torch.Tensor
-input tensor.
+x: input tensor.
 """
 x = self.layer_norm(x)

@@ -567,21 +567,20 @@ class GLULinear(nn.Module):

 def __init__(
 self,
-input_dim,
-output_dim,
-glu_type="sigmoid",
-bias_in_glu=True,
-):
+input_dim: int,
+output_dim: int,
+glu_type: str = "sigmoid",
+bias_in_glu: bool = True,
+) -> None:
 super().__init__()
 self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
 self.glu_act = GLU(-1, glu_type)

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 """GLULinear forward

 Args:
-x: torch.Tensor
-inpute tensor.
+x: input tensor.
 """
 x = self.linear(x)
 return self.glu_act(x)
@@ -609,12 +608,12 @@ class FeedForward(nn.Module):

 def __init__(
 self,
-d_model,
-d_inner,
-dropout_rate,
-activation="sigmoid",
-bias_in_glu=True,
-):
+d_model: int,
+d_inner: int,
+dropout_rate: float,
+activation: str = "sigmoid",
+bias_in_glu: bool = True,
+) -> None:
 super().__init__()
 self.d_model = d_model
 self.d_inner = d_inner
@@ -628,12 +627,11 @@ class FeedForward(nn.Module):
 nn.Dropout(dropout_rate),
 )

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 """FeedForward forward function.

 Args:
-x: torch.Tensor
-input tensor.
+x: input tensor.
 """
 out = self.net(self.layer_norm(x))

@@ -642,14 +640,14 @@ class FeedForward(nn.Module):

 #### positional encoding starts here
 def _pre_hook(
-state_dict,
-prefix,
-local_metadata,
-strict,
-missing_keys,
-unexpected_keys,
-error_msgs,
-):
+state_dict: dict,
+prefix: str,
+local_metadata: dict,
+strict: bool,
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
+) -> None:
 """Perform pre-hook in load_state_dict for backward compatibility.

 Note:
@@ -708,10 +706,10 @@ class T5RelativeAttentionLogitBias(nn.Module):
 """

 def __init__(self,
-num_heads,
-num_buckets=-1,
-max_distance=1000,
-symmetric=False):
+num_heads: int,
+num_buckets: int = -1,
+max_distance: int = 1000,
+symmetric: bool = False) -> None:
 super().__init__()
 self.num_heads = num_heads
 self.num_buckets = num_buckets
@@ -727,7 +725,7 @@ class T5RelativeAttentionLogitBias(nn.Module):
 self.num_buckets *= 2
 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
 # instantiate bias compatible with shape of x
 maxpos = x.size(1)
 context_position = torch.arange(maxpos,
@@ -760,7 +758,7 @@ class T5RelativeAttentionLogitBias(nn.Module):

 return t5_rel_att_bias

-def _bucket_relative_position(self, relative_position):
+def _bucket_relative_position(self, relative_position: Tensor) -> Tensor:
 # this is a placeholder (isn't tested, likely buggy) using HuggingFace
 # implem as a reference this also needs to be extended to support
 # asymmetric +/- ve positions
@@ -810,7 +808,10 @@ class AbsolutePositionalEncoding(nn.Module):

 """

-def __init__(self, d_model, dropout_rate, max_len=5000):
+def __init__(self,
+d_model: int,
+dropout_rate: float,
+max_len: int = 5000) -> None:
 """Construct an PositionalEncoding object."""
 super().__init__()
 self.d_model = d_model
@@ -820,11 +821,11 @@ class AbsolutePositionalEncoding(nn.Module):
 self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 self._register_load_state_dict_pre_hook(_pre_hook)

-def extend_pe(self, x):
+def extend_pe(self, x: torch.Tensor) -> None:
 """Reset the positional encodings.

 Args:
-x: torch.Tensor
+x: input tensor
 """
 if self.pe is not None and self.pe.size(1) >= x.size(1):
 if self.pe.dtype != x.dtype or self.pe.device != x.device:
@@ -840,15 +841,14 @@ class AbsolutePositionalEncoding(nn.Module):
 pe = pe.unsqueeze(0)
 self.pe = pe.to(device=x.device, dtype=x.dtype)

-def forward(self, x: torch.Tensor):
+def forward(self, x: torch.Tensor) -> torch.Tensor:
 """Add positional encoding.

 Args:
-x: torch.Tensor
-Input tensor. shape is (batch, time, ...)
+x: Input tensor. shape is (batch, time, ...)

 Returns:
-torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+Encoded tensor. Its shape is (batch, time, ...)

 """
 self.extend_pe(x)
@@ -868,7 +868,7 @@ class MeanVarianceNormLayer(nn.Module):
 layer input size.
 """

-def __init__(self, input_size):
+def __init__(self, input_size: int) -> None:
 super().__init__()
 self.input_size = input_size
 self.global_mean = nn.Parameter(torch.zeros(input_size))
@@ -878,8 +878,7 @@ class MeanVarianceNormLayer(nn.Module):
 """MeanVarianceNormLayer Forward

 Args:
-input_: torch.Tensor
-input tensor.
+input_: input tensor.
 """
 return (input_ - self.global_mean) * self.global_invstd

@@ -949,7 +948,10 @@ class CausalConv1D(nn.Conv1d):
 dtype=dtype,
 )

-def update_cache(self, x, cache=None):
+def update_cache(
+self,
+x: Tensor,
+cache: Optional[Tensor] = None) -> tuple[Tensor, Optional[Tensor]]:
 if cache is None:
 new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
 next_cache = cache
@@ -963,7 +965,11 @@ class CausalConv1D(nn.Conv1d):
 next_cache = next_cache[:, :, -cache.size(-1):]
 return new_x, next_cache

-def forward(self, x, cache=None):
+def forward(
+self,
+x: Tensor,
+cache: Optional[Tensor] = None
+) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]:
 x, cache = self.update_cache(x, cache=cache)
 x = super().forward(x)
 if cache is None:
@@ -1017,8 +1023,8 @@ class CausalConv2D(nn.Conv2d):

 def forward(
 self,
-x,
-):
+x: Tensor,
+) -> Tensor:
 x = F.pad(
 x,
 pad=(self._left_padding, self._right_padding, 0, 0),
@@ -1062,16 +1068,16 @@ class NemoConvSubsampling(torch.nn.Module):
 """

 def __init__(
 self,
-feat_in,
-feat_out,
-subsampling_factor=4,
-subsampling="dw_striding",
-conv_channels=256,
-subsampling_conv_chunking_factor=1,
-activation=nn.ReLU(), # noqa: B008
-is_causal=False,
-):
+feat_in: int,
+feat_out: int,
+subsampling_factor: int = 4,
+subsampling: str = "dw_striding",
+conv_channels: int = 256,
+subsampling_conv_chunking_factor: int = 1,
+activation: torch.nn.Module = nn.ReLU(), # noqa: B008
+is_causal: bool = False,
+) -> None:
 super().__init__()
 self._subsampling = subsampling
 self._conv_channels = conv_channels
@@ -1328,28 +1334,25 @@ class NemoConvSubsampling(torch.nn.Module):

 self.conv = torch.nn.Sequential(*layers)

-def get_sampling_frames(self):
+def get_sampling_frames(self) -> list[int]:
 return [1, self.subsampling_factor]

-def get_streaming_cache_size(self):
+def get_streaming_cache_size(self) -> list[int]:
 return [0, self.subsampling_factor + 1]

-def forward(self, x, mask):
+def forward(self, x: Tensor,
+mask: Optional[Tensor]) -> tuple[Tensor, Optional[Tensor]]:
 """
 Forward method for NeMo subsampling.

 Args:
-x[Batch, Time, Filters]: torch.Tensor
-input tensor
-x_mask: torch.Tensor
-input mask
+x: input tensor
+mask: input mask

 Returns:
-x: torch.Tensor
-Resulting tensor from subsampling (B, T //
+x: Resulting tensor from subsampling (B, T //
 time_reduction_factor, feat_out)
-pad_mask: torch.Tensor
-tensor of padded hidden state sequences (B, 1, T //
+pad_mask: tensor of padded hidden state sequences (B, 1, T //
 time_reduction_factor)
 """
 x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
@@ -1403,7 +1406,7 @@ class NemoConvSubsampling(torch.nn.Module):
 padding_length.size(0), -1) < padding_length.unsqueeze(1)
 return x, pad_mask.unsqueeze(1)

-def reset_parameters(self):
+def reset_parameters(self) -> None:
 # initialize weights
 if self._subsampling == "dw_striding":
 with torch.no_grad():
@@ -1433,7 +1436,7 @@ class NemoConvSubsampling(torch.nn.Module):
 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
 torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

-def conv_split_by_batch(self, x):
+def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]:
 """Tries to split input by batch, run conv and concat results"""
 b, _, _, _ = x.size()
 if b == 1: # can't split if batch size is 1
@@ -1460,7 +1463,7 @@ class NemoConvSubsampling(torch.nn.Module):
 True,
 )

-def conv_split_by_channel(self, x):
+def conv_split_by_channel(self, x: Tensor) -> Tensor:
 """For dw convs, tries to split input by time, run conv and concat
 results"""
 x = self.conv[0](x) # full conv2D
@@ -1500,7 +1503,8 @@ class NemoConvSubsampling(torch.nn.Module):
 x = self.conv[i * 3 + 4](x) # activation
 return x

-def channel_chunked_conv(self, conv, chunk_size, x):
+def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int,
+x: Tensor) -> Tensor:
 """Performs channel chunked convolution"""

 ind = 0
@@ -1541,7 +1545,7 @@ class NemoConvSubsampling(torch.nn.Module):
 return torch.cat(out_chunks, 1)

 def change_subsampling_conv_chunking_factor(
-self, subsampling_conv_chunking_factor: int):
+self, subsampling_conv_chunking_factor: int) -> None:
 if (subsampling_conv_chunking_factor != -1
 and subsampling_conv_chunking_factor != 1
 and subsampling_conv_chunking_factor % 2 != 0):
@@ -1552,12 +1556,12 @@ class NemoConvSubsampling(torch.nn.Module):
 self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


-def calc_length(lengths,
-all_paddings,
-kernel_size,
-stride,
-ceil_mode,
-repeat_num=1):
+def calc_length(lengths: Tensor,
+all_paddings: int,
+kernel_size: int,
+stride: int,
+ceil_mode: bool,
+repeat_num: int = 1) -> Tensor:
 """Calculates the output length of a Tensor passed through a convolution or
 max pooling layer"""
 add_pad: float = all_paddings - kernel_size
@@ -1573,11 +1577,11 @@ def calc_length(lengths,
 class AttModule(nn.Module):
 """Attention abstraction module"""

-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.export_mode = False

-def set_export(self, mode=True):
+def set_export(self, mode: bool = True) -> None:
 """set the export mode"""
 self.export_mode = mode

@@ -1591,14 +1595,10 @@ class AttModule(nn.Module):
 """AttModule forward

 Args:
-x: torch.Tensor
-input tensor.
-memory: torch.Tensor, optional
-memory tensor.
-pos_emb: torch.Tensor, optional
-positional encoder embedding.
-att_mask: torch.Tensor, optional
-attention mask tensor.
+x: input tensor.
+memory: memory tensor.
+pos_emb: positional encoder embedding.
+att_mask: attention mask tensor.
 """
 return x, memory, pos_emb, att_mask

@@ -1606,15 +1606,15 @@ class AttModule(nn.Module):
 class AttBlock(BlockBase, AttModule):
 """Attention Block module to support both Attention and Block module."""

-def memory_dims(self, max_len=False):
+def memory_dims(self, max_len: bool = False) -> tuple[int, int]:
 """memory dimensions"""
 return (1, self.input_size)


 def masked_softmax(
-scores,
+scores: Tensor,
 mask: Optional[Tensor],
-):
+) -> Tensor:
 if mask is not None:
 mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
 scores = scores.masked_fill(mask, -torch.inf)
@@ -1636,10 +1636,6 @@ class MultiHeadedAttention(nn.Module):
 input size features.
 dropout_rate: float
 dropout rate.
-use_LN: bool
-apply layer norm or not
-dropout_at_output: bool
-whether to apply dropout at output
 attention_inner_dim: int, optional
 the attention dimension used in the class,
 it can be different from the input dimension n_feat.
@@ -1666,16 +1662,16 @@ class MultiHeadedAttention(nn.Module):

 def __init__(
 self,
-n_head,
-n_feat,
-dropout_rate,
-attention_inner_dim=-1,
-glu_type="swish",
-bias_in_glu=True,
-use_pt_scaled_dot_product_attention=False,
-n_value=-1,
+n_head: int,
+n_feat: int,
+dropout_rate: float,
+attention_inner_dim: int = -1,
+glu_type: str = "swish",
+bias_in_glu: bool = True,
+use_pt_scaled_dot_product_attention: bool = False,
+n_value: int = -1,
 group_size: int = 1,
-):
+) -> None:
 super().__init__()
 if n_value == -1:
 n_value = n_feat
@@ -1718,28 +1714,22 @@ class MultiHeadedAttention(nn.Module):
 query: Tensor,
 key: Tensor,
 value: Tensor,
-pos_k: Tensor,
-pos_v: Tensor,
+pos_k: Optional[Tensor],
+pos_v: Optional[Tensor],
 mask: Optional[Tensor],
 relative_attention_bias: Optional[Tensor] = None,
-):
+) -> Tensor:
 """Compute 'Scaled Dot Product Attention'.

 Args:
-query: torch.Tensor
-query tensor (batch, time1, size)
-key: torch.Tensor
-key tensor (batch, time2, size)
-value: torch.Tensor
-value tensor (batch, time1, size)
-pos_k: torch.Tensor
-key tensor used for relative positional embedding.
-pos_v: torch.Tensor
-value tensor used for relative positional embedding.
-mask: torch.Tensor
-mask tensor (batch, time1, time2)
-relative_attention_bias: torch.Tensor
-bias added to attention logits w.r.t. relative positions
+query: query tensor (batch, time1, size)
+key: key tensor (batch, time2, size)
+value: value tensor (batch, time1, size)
+pos_k: key tensor used for relative positional embedding.
+pos_v: value tensor used for relative positional embedding.
+mask: mask tensor (batch, time1, time2)
+relative_attention_bias: bias added to attention logits w.r.t.
+relative positions
 (1, n_head, time1, time2)
 """
 n_batch = query.size(0)
@@ -1832,20 +1822,20 @@ class MultiSequential(torch.nn.Sequential):
 """Multi-input multi-output torch.nn.Sequential"""

 @torch.jit.ignore
-def forward(self, *args):
+def forward(self, *args) -> tuple:
 """Forward method implementation."""
 for m in self:
 args = m(*args)
 return args


-def get_offset(input_layer: str, time_reduction: int):
+def get_offset(input_layer: str, time_reduction: int) -> int:
 """Get an offset. We will use the offset for determining #frames of a
 subsampled feature.

 Args:
-input_layer (str): Type of an input layer
-time_reduction (int): time reduction factor for downsampling a feature
+input_layer: Type of an input layer
+time_reduction: time reduction factor for downsampling a feature
 Returns:
 int: offset
 """
@@ -1858,13 +1848,14 @@ def get_offset(input_layer: str, time_reduction: int):
 return 0


-def unfold_tensor(xs_pad, max_seq_len):
+def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor:
 """
 For a given tensor with shape of (N, T, D), if sequence length T is
 longer than max_seq_len, this function unfold it to a
 (NT', max_seq_len, D) where T' is T // max_seq_len.
 Args:
-xs_pad: N, T, D
+xs_pad: input tensor with shape (N, T, D)
+max_seq_len: maximum sequence length
 """
 _, _, D = xs_pad.shape
 xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
@@ -1193,21 +1193,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 input_ids: Flattened (concatenated) input_ids corresponding to a
 batch.
 positions: Flattened (concatenated) position ids corresponding to a
-batch.
-**NOTE**: If mrope is enabled (default setting for Qwen2.5-VL
-opensource models), the shape will be `(3, seq_len)`,
+batch. **NOTE**: If mrope is enabled (default setting for
+Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
 otherwise it will be `(seq_len,).
-pixel_values: Pixel values to be fed to a model.
-`None` if no images are passed.
-image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
-`None` if no images are passed.
-pixel_values_videos: Pixel values of videos to be fed to a model.
-`None` if no videos are passed.
-video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
-`None` if no videos are passed.
-second_per_grid_ts: Tensor `(num_videos)` of video time interval (
-in seconds) for each grid along the temporal dimension in the
-3D position IDs. `None` if no videos are passed.
 """

 if intermediate_tensors is not None:
@@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers.
 """
 from collections.abc import Iterable
 from itertools import cycle
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch import nn
@@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
 hidden_states: Input tensor [batch_size, seq_len, hidden_size]
 mamba_cache_params: Parameters for Mamba's state caches
 (one for conv, one for ssm)
-sequence_idx: Index tensor for identifying sequences in batch
-Required for proper chunked processing in prefill
 transformer_hidden_states: Optional output from transformer path
 Added to input if provided (used in hybrid architecture)
 positions: Optional position IDs (unused in Mamba)
@@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module):

 Args:
 shared_transformer: Transformer decoder layer for attention pathway
-linear: Linear projection for transformer output before Mamba
-mamba: Mamba decoder layer for state space pathway
 """
 super().__init__()
 self.block_idx = block_idx
@@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module):
 positions: Position IDs for positional embeddings
 mamba_cache_params: Parameters for Mamba's state caches
 (one for conv, one for ssm)
-sequence_idx: Indices for identifying sequences in batch,
-required for proper chunked processing in prefill

 Returns:
 Output tensor combining transformer and Mamba representations
@@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 prefix: Optional prefix for parameter names

 Raises:
-AssertionError: If prefix caching is enabled (not supported by
-Mamba)
+AssertionError: If prefix caching is enabled
+(not supported by Mamba)
 """
 config = vllm_config.model_config.hf_config
 cache_config = vllm_config.cache_config
@@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 input_ids: torch.Tensor,
 positions: torch.Tensor,
 inputs_embeds: Optional[torch.Tensor] = None,
-**kwargs) -> torch.Tensor:
+**kwargs: Any) -> torch.Tensor:
 """Forward pass through the model.

 Args:
@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):

 return hidden_states

-def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
-torch.Tensor],
-**kwargs) -> dict[str, torch.Tensor]:
+def copy_inputs_before_cuda_graphs(
+self, input_buffers: dict[str, torch.Tensor],
+**kwargs: Any) -> dict[str, torch.Tensor]:
 """Copy inputs before CUDA graph capture.

 Args:
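To check locally that changes like these actually silence the documentation warnings, the docs can be rebuilt with warnings promoted to errors. A minimal sketch, assuming mkdocs and the project's docs requirements are already installed and the command is run from the repository root (the exact requirements file and configuration may differ):

```python
import subprocess
import sys

# Rebuild the documentation with --strict so any remaining mkdocs warnings
# (for example, unparsable docstrings) fail the build instead of scrolling by.
result = subprocess.run(["mkdocs", "build", "--strict"])
sys.exit(result.returncode)
```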