Fix MiniMax-M2 rmsnorm precision and remove useless code (#27627)
Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>

parent ecca3fee76
commit d6704dd099
@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
         if self.tp_world > 1:
             variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = (x * self.weight).to(orig_dtype)
         return x

     def forward(
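The precision fix is the one-line change above: the old code downcast the normalized activations to orig_dtype (e.g. bf16) and then applied the RMSNorm gain in that low precision, rounding twice; the new code applies the gain while still in fp32 and rounds once, on the final store. A minimal standalone sketch of the difference (not vLLM code; the tensor size and dtypes here are illustrative):

import torch

torch.manual_seed(0)
orig_dtype = torch.bfloat16
x = torch.randn(4096)                      # normalized activations, fp32
weight = torch.randn(4096).to(orig_dtype)  # RMSNorm gain stored in bf16

ref = x * weight.float()                   # full-precision reference

old = x.to(orig_dtype) * weight            # old: round x first, multiply in bf16
new = (x * weight).to(orig_dtype)          # new: multiply in fp32, round once

print("old mean |err|:", (old.float() - ref).abs().mean().item())
print("new mean |err|:", (new.float() - ref).abs().mean().item())

Either way the result is returned in orig_dtype; the change only moves the downcast so it happens once, after the elementwise scale.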
@@ -263,23 +263,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
         # with the layer's index.
         layer_idx = int(prefix.split(sep=".")[-1])

-        # TODO: support MTP
-        attn_window_size = getattr(config, "attn_window_size", None)
-        if attn_window_size is not None:
-            if isinstance(attn_window_size, list):
-                attn_window_size = attn_window_size[layer_idx]
-            elif isinstance(attn_window_size, int):
-                attn_window_size = attn_window_size
-            else:
-                raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
-            attn_window_size = None if attn_window_size <= 0 else attn_window_size
-
-        # different rope theta for full layer and swa layer
-        swa_rope_theta = getattr(config, "swa_rope_theta", -1)
-        # default to full rope theta
-        swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
-        rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
-
         self.layer_idx = layer_idx
         self.self_attn = MiniMaxM2Attention(
             hidden_size=self.hidden_size,
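The block deleted above resolved a per-layer sliding-window size and a separate rope theta for SWA layers, and the hunk below drops the now-unused attn_window_size argument from the MiniMaxM2Attention call. For reference, the deleted window lookup amounted to the following (a standalone sketch; resolve_window is a hypothetical name, the original ran inline in __init__):

def resolve_window(attn_window_size, layer_idx):
    # Per-layer sliding-window lookup, as the removed block did it:
    # a list selects by layer index, an int applies to every layer,
    # anything else is rejected; non-positive values mean "no window".
    if attn_window_size is None:
        return None
    if isinstance(attn_window_size, list):
        attn_window_size = attn_window_size[layer_idx]
    elif not isinstance(attn_window_size, int):
        raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
    return None if attn_window_size <= 0 else attn_window_size

assert resolve_window([0, 128, 0], 1) == 128   # list: per-layer value
assert resolve_window(4096, 2) == 4096         # int: same for all layers
assert resolve_window(-1, 0) is None           # non-positive disables SWA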
@@ -288,7 +271,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
             rotary_dim=config.rotary_dim,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
-            attn_window_size=attn_window_size,
             max_position_embeddings=max_position_embeddings,
             rms_norm_eps=config.rms_norm_eps,
             qkv_bias=getattr(config, "attention_bias", False),