From 3b1d440ede42855f031ba72af4817583e5dddba0 Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Wed, 17 Dec 2025 17:43:00 +0800
Subject: [PATCH] CustomOp: grouped topk (#29575)

Signed-off-by: Xinyu Chen
---
 tests/kernels/moe/test_grouped_topk.py        | 10 ++--
 .../layers/fused_moe/__init__.py              |  4 +-
 .../layers/fused_moe/fused_moe.py             | 52 +++++++++++++++++++
 vllm/model_executor/layers/fused_moe/layer.py | 23 +++++---
 4 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py
index 662e0723b7583..d26fe50b815b4 100644
--- a/tests/kernels/moe/test_grouped_topk.py
+++ b/tests/kernels/moe/test_grouped_topk.py
@@ -9,8 +9,8 @@
 import pytest
 import torch

 from vllm.model_executor.layers.fused_moe.fused_moe import (
+    GroupedTopk,
     fused_grouped_topk,
-    grouped_topk,
 )
 from vllm.platforms import current_platform
@@ -50,15 +50,17 @@ def test_grouped_topk(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
-        baseline_topk_weights, baseline_topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=gating_output,
+        grouped_topk = GroupedTopk(
             topk=topk,
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
+        )
+        baseline_topk_weights, baseline_topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
             e_score_correction_bias=e_score_correction_bias,
         )
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index d71cfc5ad8200..8fee4038b60b8 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -77,11 +77,11 @@ if HAS_TRITON:
         BatchedTritonExperts,
     )
     from vllm.model_executor.layers.fused_moe.fused_moe import (
+        GroupedTopk,
         TritonExperts,
         fused_experts,
         fused_topk,
         get_config_file_name,
-        grouped_topk,
     )
     from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
         TritonOrDeepGemmExperts,
@@ -91,7 +91,7 @@ if HAS_TRITON:
         "fused_topk",
         "fused_experts",
         "get_config_file_name",
-        "grouped_topk",
+        "GroupedTopk",
         "cutlass_moe_fp8",
         "cutlass_moe_fp4",
         "cutlass_moe_w4a8_fp8",
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index b286c3bc6fc07..20782e2712f27 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -16,6 +16,7 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -1286,6 +1287,57 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+@CustomOp.register("grouped_topk")
+class GroupedTopk(CustomOp):
+    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 models."""
+
+    def __init__(
+        self,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.native_impl = grouped_topk
+        self.topk = topk
+        self.renormalize = renormalize
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.native_impl(
+            hidden_states,
+            gating_output,
+            self.topk,
+            self.renormalize,
+            self.num_expert_group,
+            self.topk_group,
+            self.scoring_func,
+            self.routed_scaling_factor,
+            e_score_correction_bias,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(
+            hidden_states, gating_output, e_score_correction_bias
+        )
+
+
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def eplb_map_to_physical_and_record(
     topk_ids: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b39ce415a0f83..db97d6eb88ea5 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -67,7 +67,7 @@ else:
         return topk_ids

     eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record

-from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
     rocm_aiter_grouped_topk,
 )
@@ -1594,19 +1594,26 @@ class FusedMoE(CustomOp):
                 grouped_topk_impl = partial(
                     rocm_aiter_grouped_topk,
                     num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk=self.top_k,
+                    renormalize=self.renormalize,
+                    num_expert_group=self.num_expert_group,
+                    topk_group=self.topk_group,
+                    scoring_func=self.scoring_func,
+                    routed_scaling_factor=self.routed_scaling_factor,
                 )
             else:
-                grouped_topk_impl = grouped_topk
+                grouped_topk_impl = GroupedTopk(
+                    topk=self.top_k,
+                    renormalize=self.renormalize,
+                    num_expert_group=self.num_expert_group,
+                    topk_group=self.topk_group,
+                    scoring_func=self.scoring_func,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                )
             topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
-                topk=self.top_k,
-                renormalize=self.renormalize,
-                num_expert_group=self.num_expert_group,
-                topk_group=self.topk_group,
-                scoring_func=self.scoring_func,
-                routed_scaling_factor=self.routed_scaling_factor,
                 e_score_correction_bias=self.e_score_correction_bias,
             )
         elif self.e_score_correction_bias is not None:
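
Usage sketch (not part of the diff): the snippet below drives the new
GroupedTopk op the same way the updated test does. Tensor sizes, dtypes, and
the sigmoid scoring choice are illustrative assumptions, and the re-export
from vllm.model_executor.layers.fused_moe is only available behind HAS_TRITON
per the __init__.py hunk above.

    import torch

    from vllm.model_executor.layers.fused_moe import GroupedTopk

    # Illustrative routing config: 64 experts in 8 groups; keep the best
    # 3 groups per token, then pick the top-6 experts among them.
    num_tokens, hidden_size, num_experts = 4, 128, 64

    grouped_topk = GroupedTopk(
        topk=6,
        renormalize=True,
        num_expert_group=8,
        topk_group=3,
        scoring_func="sigmoid",
        routed_scaling_factor=1.0,
    )

    hidden_states = torch.randn(num_tokens, hidden_size)
    gating_output = torch.randn(num_tokens, num_experts)
    e_score_correction_bias = torch.randn(num_experts)

    # CustomOp dispatch routes the call to forward_cuda or forward_native
    # for the current platform; both delegate to the original grouped_topk.
    topk_weights, topk_ids = grouped_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        e_score_correction_bias=e_score_correction_bias,
    )

    assert topk_weights.shape == (num_tokens, 6)  # float32 routing weights
    assert topk_ids.shape == (num_tokens, 6)      # int32 expert indices

The design point the layer.py hunk leans on: routing hyperparameters are now
bound once at construction (mirroring the functools.partial wrapper around
rocm_aiter_grouped_topk), so the call site passes only tensors and platform
backends can override forward_* without any change at the caller.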