[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-12-10 10:23:07 +08:00 committed by GitHub
parent 6d525288c1
commit d1f6d1c8af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 10 deletions

View File

@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None:
super().__init__()
@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward_native(
self,
@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:

View File

@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu"):
super().__init__()
@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
input_is_parallel=True,
)
self.dt_layernorm = RMSNorm(time_step_rank,
eps=rms_norm_eps) if use_rms_norm else None
self.dt_layernorm = RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = RMSNorm(ssm_state_size,
eps=rms_norm_eps) if use_rms_norm else None
self.b_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = RMSNorm(ssm_state_size,
eps=rms_norm_eps) if use_rms_norm else None
self.c_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,

View File

@ -1,5 +1,5 @@
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act)
@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params