CustomOp: grouped topk (#29575)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Xinyu Chen, 2025-12-17 17:43:00 +08:00 (committed by GitHub)
parent a9e15c21ef
commit 3b1d440ede
4 changed files with 75 additions and 14 deletions
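
In short, this change wraps the free function grouped_topk in a CustomOp subclass, GroupedTopk, registered under the name "grouped_topk". The routing configuration (topk, renormalize, expert grouping, scoring function, scaling factor) moves to the constructor, and call sites pass tensors only, which lets CustomOp dispatch to platform-specific implementations. A minimal before/after sketch of the call pattern (variable values are placeholders, not taken from the diff):

    # Before: one call carries both routing config and tensors.
    topk_weights, topk_ids = grouped_topk(
        hidden_states, gating_output, topk, renormalize,
        num_expert_group, topk_group, scoring_func,
        routed_scaling_factor, e_score_correction_bias,
    )

    # After: config is bound once; the op is then called with tensors only.
    op = GroupedTopk(
        topk=topk,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
    )
    topk_weights, topk_ids = op(hidden_states, gating_output, e_score_correction_bias)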

tests/kernels/moe/test_grouped_topk.py

@@ -9,8 +9,8 @@ import pytest
 import torch
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_grouped_topk,
-    grouped_topk,
+    GroupedTopk,
 )
 from vllm.platforms import current_platform
@@ -50,15 +50,17 @@ def test_grouped_topk(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
-        baseline_topk_weights, baseline_topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=gating_output,
+        grouped_topk = GroupedTopk(
             topk=topk,
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
+        )
+        baseline_topk_weights, baseline_topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
             e_score_correction_bias=e_score_correction_bias,
         )
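
For reference, a standalone usage sketch of the new op, assuming vLLM is installed; the shapes and routing parameters below are illustrative and are not taken from the test:

    import torch

    from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk

    # 16 tokens, hidden size 128, 64 routed experts split into 8 groups.
    op = GroupedTopk(
        topk=4,
        renormalize=True,
        num_expert_group=8,
        topk_group=4,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
    )
    hidden_states = torch.randn(16, 128)
    gating_output = torch.randn(16, 64)
    topk_weights, topk_ids = op(hidden_states, gating_output)

    # grouped_topk (see fused_moe.py below) returns float32 weights
    # and int32 expert ids.
    assert topk_weights.shape == (16, 4)
    assert topk_ids.dtype == torch.int32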

vllm/model_executor/layers/fused_moe/__init__.py

@@ -77,11 +77,11 @@ if HAS_TRITON:
         BatchedTritonExperts,
     )
     from vllm.model_executor.layers.fused_moe.fused_moe import (
+        GroupedTopk,
         TritonExperts,
         fused_experts,
         fused_topk,
         get_config_file_name,
-        grouped_topk,
     )
     from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
         TritonOrDeepGemmExperts,
@@ -91,7 +91,7 @@ if HAS_TRITON:
         "fused_topk",
         "fused_experts",
         "get_config_file_name",
-        "grouped_topk",
+        "GroupedTopk",
         "cutlass_moe_fp8",
         "cutlass_moe_fp4",
         "cutlass_moe_w4a8_fp8",

vllm/model_executor/layers/fused_moe/fused_moe.py

@@ -16,6 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -1286,6 +1287,57 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+@CustomOp.register("grouped_topk")
+class GroupedTopk(CustomOp):
+    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 models."""
+
+    def __init__(
+        self,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.native_impl = grouped_topk
+        self.topk = topk
+        self.renormalize = renormalize
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.native_impl(
+            hidden_states,
+            gating_output,
+            self.topk,
+            self.renormalize,
+            self.num_expert_group,
+            self.topk_group,
+            self.scoring_func,
+            self.routed_scaling_factor,
+            e_score_correction_bias,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(
+            hidden_states, gating_output, e_score_correction_bias
+        )
+
+
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def eplb_map_to_physical_and_record(
     topk_ids: torch.Tensor,
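
forward_native above is the eager PyTorch reference path, and forward_cuda simply delegates to it, so CUDA behavior is unchanged by this commit. The value of the CustomOp wrapper is dispatch: the base class routes calls to a per-platform forward_* method, so a backend can later swap in a fused kernel without touching call sites. A hypothetical sketch of such an override (forward_xpu is one of CustomOp's per-platform hooks; the kernel function named here is invented for illustration and is not part of this diff):

    # Hypothetical method on GroupedTopk; xpu_fused_grouped_topk is a
    # made-up stand-in for a real device kernel.
    def forward_xpu(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return xpu_fused_grouped_topk(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
        )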

vllm/model_executor/layers/fused_moe/layer.py

@@ -67,7 +67,7 @@ else:
         return topk_ids
 
     eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
-from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
     rocm_aiter_grouped_topk,
 )
@@ -1594,19 +1594,26 @@ class FusedMoE(CustomOp):
             grouped_topk_impl = partial(
                 rocm_aiter_grouped_topk,
                 num_fused_shared_experts=self.num_fused_shared_experts,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
             )
         else:
-            grouped_topk_impl = grouped_topk
+            grouped_topk_impl = GroupedTopk(
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
         topk_weights, topk_ids = grouped_topk_impl(
             hidden_states=hidden_states,
             gating_output=router_logits,
-            topk=self.top_k,
-            renormalize=self.renormalize,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            scoring_func=self.scoring_func,
-            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
         )
     elif self.e_score_correction_bias is not None:
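
Net effect in FusedMoE.select_experts: both branches now bind the routing configuration up front (functools.partial for the ROCm aiter path, the GroupedTopk constructor otherwise), so the shared call passes tensors only and stays backend-agnostic. A small sketch of the resulting duck-typed contract (the helper name is illustrative, not from the diff):

    # Either a functools.partial over rocm_aiter_grouped_topk or a
    # GroupedTopk instance works: both accept tensor-only keyword calls.
    def run_routing(impl, hidden_states, router_logits, bias):
        return impl(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=bias,
        )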