diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index d93cef1a27ad..5fe37a6289e0 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -83,17 +83,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
         variance = tensor_model_parallel_all_reduce(
             variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-
-        weight = self.weight
-        if x.size(-1) != self.weight.size(0):
-            if self.weight.size(0) < x.size(-1):
-                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
-                full_weight = self.weight.repeat(repeat_count)
-                weight = full_weight[:x.size(-1)]
-            else:
-                weight = self.weight[:x.size(-1)]
-
-        x = x.to(orig_dtype) * weight
+        x = x.to(orig_dtype) * self.weight
         return x

     def forward(
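
For context, a minimal sketch of what the simplified path computes, assuming the hidden dimension is evenly sharded across tensor-parallel ranks so `self.weight.size(0)` always equals `x.size(-1)` and the removed repeat/truncate fallback is never needed. `RMSNormTPSketch` and `fake_all_reduce` below are hypothetical stand-ins (not vLLM APIs); the real class uses `tensor_model_parallel_all_reduce` and `self.tp_world` as shown in the diff.

```python
# Sketch of the tensor-parallel RMSNorm math after this change.
# Assumption: each rank holds an equal shard of the hidden dimension,
# so the per-shard weight matches x.size(-1) and is applied directly.
import torch


def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Single-process stand-in for tensor_model_parallel_all_reduce:
    # with one "rank" the reduced sum is the tensor itself.
    return t


class RMSNormTPSketch(torch.nn.Module):

    def __init__(self, sharded_hidden_size: int, tp_world: int,
                 eps: float = 1e-6):
        super().__init__()
        self.tp_world = tp_world
        self.variance_epsilon = eps
        # Each rank only owns its shard of the gain vector.
        self.weight = torch.nn.Parameter(torch.ones(sharded_hidden_size))

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        # Per-shard mean of squares, then all-reduce and average over ranks
        # so every rank normalizes with statistics of the full hidden dim.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        variance = fake_all_reduce(variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # After the fix, the weight is multiplied in directly; its length
        # must equal the shard's last dimension.
        return x.to(orig_dtype) * self.weight


if __name__ == "__main__":
    norm = RMSNormTPSketch(sharded_hidden_size=64, tp_world=1)
    out = norm._forward(torch.randn(2, 8, 64))
    print(out.shape)  # torch.Size([2, 8, 64])
```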