From d1e1fb4363c61080b7cd20469d5a751e88a1cdb3 Mon Sep 17 00:00:00 2001
From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Date: Wed, 10 Dec 2025 21:47:18 -0600
Subject: [PATCH] [Bugfix] Fix grouped_topk pytorch impl when num_experts
 can't be grouped properly (#29439)

Signed-off-by: Divakar Verma
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: TJian
---
 vllm/model_executor/layers/fused_moe/layer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 61dd1892d67ea..7f803720d4770 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1556,6 +1556,14 @@ class FusedMoE(CustomOp):
                 f"EPLB is not supported for {self.quant_method.method_name}."
             )
 
+        def valid_grouping() -> bool:
+            # Check if num_experts is greater than num_expert_group
+            # and is divisible by num_expert_group
+            num_experts = router_logits.shape[-1]
+            if num_experts <= self.num_expert_group:
+                return False
+            return num_experts % self.num_expert_group == 0
+
         indices_type = self.quant_method.topk_indices_dtype
 
         # Check if we should use a routing simulation strategy
@@ -1570,7 +1578,7 @@ class FusedMoE(CustomOp):
             )
 
         # DeepSeekv2 uses grouped_top_k
-        elif self.use_grouped_topk:
+        elif self.use_grouped_topk and valid_grouping():
             assert self.topk_group is not None
             assert self.num_expert_group is not None
             if rocm_aiter_ops.is_fused_moe_enabled():
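
A minimal standalone sketch of the check this patch introduces, for readers outside the vLLM codebase. The real `valid_grouping` is a closure inside `FusedMoE` that reads `num_experts` from `router_logits.shape[-1]` and `self.num_expert_group`; here those are modeled as plain arguments, and the example expert counts are hypothetical.

```python
def valid_grouping(num_experts: int, num_expert_group: int) -> bool:
    """Return True only if the experts can be split evenly into expert groups."""
    if num_experts <= num_expert_group:
        # Fewer experts than groups: grouped top-k cannot form valid groups.
        return False
    # Grouped top-k requires an equal number of experts per group.
    return num_experts % num_expert_group == 0


# Hypothetical expert counts illustrating the cases the guard covers:
assert valid_grouping(64, 8) is True    # 8 groups of 8 experts -> grouped top-k is fine
assert valid_grouping(6, 8) is False    # fewer experts than groups -> skip grouped path
assert valid_grouping(10, 3) is False   # not evenly divisible -> skip grouped path
```

With the patch, when `valid_grouping()` returns False the `elif` branch is not taken and routing falls through to the non-grouped top-k path instead of attempting an invalid reshape into expert groups.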