mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 14:06:32 +08:00

CustomOp: grouped topk (#29575)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>

parent a9e15c21ef
commit 3b1d440ede
@@ -9,8 +9,8 @@ import pytest
 import torch
 
 from vllm.model_executor.layers.fused_moe.fused_moe import (
+    GroupedTopk,
     fused_grouped_topk,
-    grouped_topk,
 )
 from vllm.platforms import current_platform
 
@@ -50,15 +50,17 @@ def test_grouped_topk(
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
-        baseline_topk_weights, baseline_topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=gating_output,
-            topk=topk,
-            renormalize=renormalize,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
-        )
+        grouped_topk = GroupedTopk(
+            topk=topk,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+        baseline_topk_weights, baseline_topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            e_score_correction_bias=e_score_correction_bias,
+        )
 
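For context, the updated test drives the class directly: the static routing configuration is bound once on the GroupedTopk instance, and each call passes only the per-batch tensors. A minimal standalone sketch, assuming a vLLM build that contains this commit; the shapes and routing parameters below are illustrative, not taken from the test:

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk

# Illustrative sizes: 4 tokens, hidden size 16, 8 experts in 4 groups.
hidden_states = torch.randn(4, 16)
gating_output = torch.randn(4, 8)

op = GroupedTopk(
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
    scoring_func="softmax",
    routed_scaling_factor=1.0,
)
# e_score_correction_bias is optional and defaults to None.
topk_weights, topk_ids = op(
    hidden_states=hidden_states,
    gating_output=gating_output,
)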
@@ -77,11 +77,11 @@ if HAS_TRITON:
         BatchedTritonExperts,
     )
     from vllm.model_executor.layers.fused_moe.fused_moe import (
+        GroupedTopk,
         TritonExperts,
         fused_experts,
         fused_topk,
         get_config_file_name,
-        grouped_topk,
     )
     from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
         TritonOrDeepGemmExperts,
@@ -91,7 +91,7 @@ if HAS_TRITON:
         "fused_topk",
         "fused_experts",
         "get_config_file_name",
-        "grouped_topk",
+        "GroupedTopk",
         "cutlass_moe_fp8",
         "cutlass_moe_fp4",
         "cutlass_moe_w4a8_fp8",
@@ -16,6 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
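The CustomOp import above is what the new class builds on. As a rough illustration of the pattern — a simplified sketch, not vLLM's actual CustomOp, which also handles enabling and disabling ops through the compilation config — a registered op resolves its forward to a platform-specific implementation once, falling back to forward_native:

import torch


class MiniCustomOp(torch.nn.Module):
    """Simplified stand-in for vllm.model_executor.custom_op.CustomOp."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        # Decorator that records the op class under a stable name.
        def wrap(op_cls: type) -> type:
            cls._registry[name] = op_cls
            return op_cls

        return wrap

    def __init__(self) -> None:
        super().__init__()
        # Resolve the dispatch target once, at construction time.
        if torch.cuda.is_available() and hasattr(self, "forward_cuda"):
            self._forward_impl = self.forward_cuda
        else:
            self._forward_impl = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_impl(*args, **kwargs)

Subclasses then implement forward_native as the reference path, plus optional per-platform overrides such as forward_cuda — exactly the two methods GroupedTopk defines below.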
@@ -1286,6 +1287,57 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+@CustomOp.register("grouped_topk")
+class GroupedTopk(CustomOp):
+    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
+
+    def __init__(
+        self,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.native_impl = grouped_topk
+        self.topk = topk
+        self.renormalize = renormalize
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.native_impl(
+            hidden_states,
+            gating_output,
+            self.topk,
+            self.renormalize,
+            self.num_expert_group,
+            self.topk_group,
+            self.scoring_func,
+            self.routed_scaling_factor,
+            e_score_correction_bias,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(
+            hidden_states, gating_output, e_score_correction_bias
+        )
+
+
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def eplb_map_to_physical_and_record(
     topk_ids: torch.Tensor,
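For readers new to this routing scheme, the computation behind the native implementation is, in outline: experts are partitioned into groups, whole groups are shortlisted by their strongest expert score, and the final top-k is taken only over experts in the shortlisted groups. A simplified sketch of that idea (softmax scoring only; it omits the sigmoid path, e_score_correction_bias, and routed_scaling_factor that the real grouped_topk supports):

import torch


def grouped_topk_sketch(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    num_tokens, num_experts = gating_output.shape
    assert num_experts % num_expert_group == 0
    scores = torch.softmax(gating_output, dim=-1)

    # Score each group by its strongest expert; keep topk_group groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices

    # Zero out experts that live in non-selected groups.
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.repeat_interleave(
        num_experts // num_expert_group, dim=1
    )
    masked_scores = scores * expert_mask

    # Ordinary top-k over the surviving experts.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)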
@@ -67,7 +67,7 @@ else:
         return topk_ids
 
     eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
 
-from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
     rocm_aiter_grouped_topk,
 )
@@ -1594,19 +1594,26 @@ class FusedMoE(CustomOp):
             grouped_topk_impl = partial(
                 rocm_aiter_grouped_topk,
                 num_fused_shared_experts=self.num_fused_shared_experts,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
             )
         else:
-            grouped_topk_impl = grouped_topk
+            grouped_topk_impl = GroupedTopk(
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
 
         topk_weights, topk_ids = grouped_topk_impl(
             hidden_states=hidden_states,
             gating_output=router_logits,
-            topk=self.top_k,
-            renormalize=self.renormalize,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            scoring_func=self.scoring_func,
-            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
         )
     elif self.e_score_correction_bias is not None: