[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent 6d525288c1
commit d1f6d1c8af
@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
     ) -> None:
         super().__init__()

@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
         self.variance_epsilon = eps
         self.variance_size_override = (None if var_hidden_size == hidden_size
                                        else var_hidden_size)
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.has_weight = has_weight
+
+        self.weight = torch.ones(hidden_size)
+        if self.has_weight:
+            self.weight = nn.Parameter(self.weight)

     def forward_native(
         self,
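The conditional nn.Parameter registration above means that with has_weight=False the all-ones weight stays a plain tensor, so it never appears in named_parameters() and is not a target for checkpoint loading or training. A minimal sketch of that behavior (illustrative only; the import path is an assumption and may differ across vLLM versions):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm  # assumed import path

# Weightless norm: `weight` exists but is a plain tensor, not a registered parameter.
norm = RMSNorm(hidden_size=16, has_weight=False)
assert not isinstance(norm.weight, torch.nn.Parameter)
assert "weight" not in dict(norm.named_parameters())

# Default behavior is unchanged: the scale is still a learnable parameter.
norm_default = RMSNorm(hidden_size=16)
assert isinstance(norm_default.weight, torch.nn.Parameter)
assert "weight" in dict(norm_default.named_parameters())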
@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)

         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = x.to(orig_dtype)
+        if self.has_weight:
+            x = x * self.weight
         if residual is None:
             return x
         else:
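With has_weight=False, forward_native reduces to plain RMS normalization: scale by the reciprocal RMS in float32, cast back to the original dtype, and skip the elementwise multiply. A hedged reference comparison (the plain_rms_norm helper below is illustrative, not part of this diff):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm  # assumed import path

def plain_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), computed in float32 and cast back.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(orig_dtype)

x = torch.randn(2, 8, dtype=torch.float16)
norm = RMSNorm(hidden_size=8, has_weight=False)
torch.testing.assert_close(norm.forward_native(x), plain_rms_norm(x))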
@@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
                  use_conv_bias: bool,
                  use_bias: bool,
                  use_rms_norm: bool,
+                 rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
                  activation="silu"):
         super().__init__()
@@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
             input_is_parallel=True,
         )

-        self.dt_layernorm = RMSNorm(time_step_rank,
-                                    eps=rms_norm_eps) if use_rms_norm else None
+        self.dt_layernorm = RMSNorm(
+            time_step_rank,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None

-        self.b_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
+        self.b_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None

-        self.c_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
+        self.c_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
                        attn_metadata: AttentionMetadata,
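MambaMixer now threads the new flag into its dt/B/C norms, so a model can request RMS normalization of those streams without a learnable scale. A hedged sketch of how the two flags combine (the build_norm helper is illustrative, not vLLM API; the import path is assumed):

from typing import Optional
from vllm.model_executor.layers.layernorm import RMSNorm  # assumed import path

def build_norm(size: int, use_rms_norm: bool, has_weight: bool,
               eps: float = 1e-5) -> Optional[RMSNorm]:
    # Mirrors the conditional expressions above: no norm at all when
    # use_rms_norm is False, otherwise an RMSNorm whose learnable scale
    # is controlled by has_weight.
    if not use_rms_norm:
        return None
    return RMSNorm(size, eps=eps, has_weight=has_weight)

time_step_rank, ssm_state_size = 16, 16
dt_layernorm = build_norm(time_step_rank, use_rms_norm=True, has_weight=False)
b_layernorm = build_norm(ssm_state_size, use_rms_norm=True, has_weight=False)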
@@ -1,5 +1,5 @@
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple

 import torch
 from torch import nn
@@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
+                                rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)

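The decoder-layer wiring derives both flags from a single condition: for FalconMamba the mixer normalizes dt/B/C with a weightless RMSNorm (use_rms_norm=True, has_weight=False), while plain Mamba keeps the previous behavior of applying no dt/B/C norm at all. A small sketch of the resulting truth table (the loop is only an illustration of the two keyword arguments above):

for is_falcon_mamba in (True, False):
    use_rms_norm = is_falcon_mamba             # FalconMamba: normalize dt/B/C
    rms_norm_has_weight = not is_falcon_mamba  # ...but without a learnable scale
    print(f"{is_falcon_mamba=} {use_rms_norm=} {rms_norm_has_weight=}")
# is_falcon_mamba=True  -> weightless RMSNorm on dt/B/C
# is_falcon_mamba=False -> no dt/B/C norm (both flags are effectively no-ops)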
@@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
@@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
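load_weights now reports which parameter names it actually loaded, which is what re-enables the weights-loading tracker for Mamba: a caller can compare the returned set against the model's registered parameters to surface checkpoint mismatches. A hedged usage sketch, assuming only that the model's load_weights returns that set:

from typing import Iterable, Set, Tuple
import torch

def check_all_loaded(model: torch.nn.Module,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
    # `model` stands in for a constructed MambaForCausalLM (or any model whose
    # load_weights returns the set of loaded parameter names).
    loaded_params = model.load_weights(weights)
    missing = set(dict(model.named_parameters())) - loaded_params
    if missing:
        raise ValueError(f"Checkpoint is missing weights for: {sorted(missing)}")
    return loaded_params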