diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d0b6249e1c33..543c8ced165a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -939,15 +939,17 @@ def grouped_topk(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] if e_score_correction_bias is not None: # Store original scores before applying correction bias. We use biased # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - - num_token = scores.shape[0] - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] @@ -955,7 +957,8 @@ def grouped_topk(hidden_states: torch.Tensor, score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]