From d6704dd099b71a9881613dee74b440d78594fa8f Mon Sep 17 00:00:00 2001
From: Roger Young <42564206+rogeryoungh@users.noreply.github.com>
Date: Wed, 29 Oct 2025 21:01:05 +0800
Subject: [PATCH] Fix MiniMax-M2 rmsnorm precision and remove useless code
 (#27627)

Signed-off-by: xuebi
Co-authored-by: xuebi
---
 .../model_executor/layers/mamba/linear_attn.py |  2 +-
 vllm/model_executor/models/minimax_m2.py       | 18 ------------------
 2 files changed, 1 insertion(+), 19 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index fd4567ee47018..0a2742ff49a44 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
         if self.tp_world > 1:
             variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = (x * self.weight).to(orig_dtype)
         return x
 
     def forward(
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index dadb8a19c004e..21ed428a05d0f 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -263,23 +263,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
         # with the layer's index.
         layer_idx = int(prefix.split(sep=".")[-1])
 
-        # TODO: support MTP
-        attn_window_size = getattr(config, "attn_window_size", None)
-        if attn_window_size is not None:
-            if isinstance(attn_window_size, list):
-                attn_window_size = attn_window_size[layer_idx]
-            elif isinstance(attn_window_size, int):
-                attn_window_size = attn_window_size
-            else:
-                raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
-            attn_window_size = None if attn_window_size <= 0 else attn_window_size
-
-        # different rope theta for full layer and swa layer
-        swa_rope_theta = getattr(config, "swa_rope_theta", -1)
-        # default to full rope theta
-        swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
-        rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
-
         self.layer_idx = layer_idx
         self.self_attn = MiniMaxM2Attention(
             hidden_size=self.hidden_size,
@@ -288,7 +271,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
             rotary_dim=config.rotary_dim,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
-            attn_window_size=attn_window_size,
             max_position_embeddings=max_position_embeddings,
             rms_norm_eps=config.rms_norm_eps,
             qkv_bias=getattr(config, "attention_bias", False),
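
A minimal standalone sketch of the precision issue the linear_attn.py hunk addresses, assuming a bfloat16 activation and a bfloat16 RMSNorm gain; the shapes, dtypes, epsilon, and seed below are illustrative and not taken from vLLM. Casting the normalized activations back to the original dtype before applying the weight rounds to bfloat16 twice, while the patched order scales in float32 and rounds the product only once.

    # Illustrative sketch only: shapes, dtypes, epsilon, and seed are assumptions.
    import torch

    torch.manual_seed(0)
    eps = 1e-6
    hidden = 1024
    x = torch.randn(4, hidden, dtype=torch.bfloat16)
    weight = torch.randn(hidden, dtype=torch.bfloat16)  # RMSNorm gain, assumed bf16 like the model

    orig_dtype = x.dtype
    xf = x.float()  # normalization statistics computed in float32
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(variance + eps)

    old = xf.to(orig_dtype) * weight    # pre-patch order: round to bf16, then scale in bf16
    new = (xf * weight).to(orig_dtype)  # patched order: scale in fp32, round the product once
    ref = xf * weight.float()           # float32 reference

    print("old max error:", (old.float() - ref).abs().max().item())
    print("new max error:", (new.float() - ref).abs().max().item())

On typical inputs the second printed error should be noticeably smaller than the first, which is the accuracy the one-line reordering recovers.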