[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-12-10 10:23:07 +08:00 committed by GitHub
parent 6d525288c1
commit d1f6d1c8af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 10 deletions

View File

@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
hidden_size: int, hidden_size: int,
eps: float = 1e-6, eps: float = 1e-6,
var_hidden_size: Optional[int] = None, var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
self.variance_epsilon = eps self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size) else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size)) self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward_native( def forward_native(
self, self,
@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None: if residual is None:
return x return x
else: else:

View File

@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
use_conv_bias: bool, use_conv_bias: bool,
use_bias: bool, use_bias: bool,
use_rms_norm: bool, use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5, rms_norm_eps: float = 1e-5,
activation="silu"): activation="silu"):
super().__init__() super().__init__()
@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
input_is_parallel=True, input_is_parallel=True,
) )
self.dt_layernorm = RMSNorm(time_step_rank, self.dt_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = RMSNorm(ssm_state_size, self.b_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = RMSNorm(ssm_state_size, self.c_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
def forward_native(self, hidden_states: torch.Tensor, def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,

View File

@ -1,5 +1,5 @@
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
use_conv_bias=config.use_conv_bias, use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias, use_bias=config.use_bias,
use_rms_norm=self.is_falcon_mamba, use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps, rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act) activation=config.hidden_act)
@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params