mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 15:24:28 +08:00

[FEAT] [ROCm] [V1]: Add AITER biased group topk for DeepSeekV3 (#17955)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

This commit is contained in:
parent 12e6c0b41c
commit 2d912fb66f
tests/kernels/moe/test_rocm_aiter_topk.py (new file, 122 lines)
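For orientation before reading the diff: the routing this kernel accelerates is DeepSeek-V3's biased, group-limited top-k. Experts are scored with a sigmoid; a per-expert correction bias is added only when choosing groups and experts (the returned weights use the unbiased scores); only experts inside the best topk_group of the num_expert_group groups are eligible; and the selected weights are optionally renormalized and then multiplied by routed_scaling_factor. The sketch below is a plain-PyTorch reference of that math for reading purposes, not the AITER kernel; the top-2-per-group group scoring detail is taken from vLLM's reference grouped_topk rather than from this diff.

# Plain-PyTorch reference sketch of DeepSeek-V3 style biased grouped top-k
# routing (illustrative only; the AITER kernel fuses this into one op).
import torch


def biased_grouped_topk_reference(
        gating_output: torch.Tensor,            # (num_tokens, num_experts)
        e_score_correction_bias: torch.Tensor,  # (num_experts,)
        topk: int,
        num_expert_group: int,
        topk_group: int,
        renormalize: bool = True,
        routed_scaling_factor: float = 1.0):
    num_token, num_expert = gating_output.shape
    scores = gating_output.sigmoid()
    # The bias only influences which groups/experts are picked,
    # not the routing weights that are returned.
    scores_for_choice = scores + e_score_correction_bias

    # Score each group by the sum of its two best (biased) expert scores,
    # then keep only the best `topk_group` groups.
    group_scores = scores_for_choice.view(num_token, num_expert_group,
                                          -1).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = group_scores.topk(topk_group, dim=-1)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        num_expert // num_expert_group).reshape(num_token, -1)

    # Top-k over the surviving experts; weights come from the unbiased scores.
    masked_scores = scores_for_choice.masked_fill(score_mask == 0,
                                                  float("-inf"))
    topk_ids = masked_scores.topk(topk, dim=-1)[1]
    topk_weights = scores.gather(1, topk_ids).to(torch.float32)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids.to(torch.int32)

The fused AITER op added in this commit computes the same routing in a single kernel and writes into pre-allocated topk_weights / topk_ids buffers instead of returning fresh tensors.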
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# This is a test for the AITER ops.
# It tests that the AITER ops:
# 1. are correctly registered as custom ops,
# 2. correctly define the relationship between the
#    implementation and the fake function, and
# 3. can be used with torch.compile.
# This file will be skipped if AITER is not installed
# or the platform is not ROCm.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
from vllm.platforms import current_platform

# need to import once to ensure the ops are registered
# Check if aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed")


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Test that the custom op is correctly registered."""
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile."""
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")
    e_score_correction_bias = torch.randn((expert, ),
                                          dtype=torch.bfloat16,
                                          device="cuda")

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk),
                               dtype=torch.float32,
                               device=device)

    # Define a function that uses the op
    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                               topk_weights, topk_ids):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output, e_score_correction_bias, topk_weights, topk_ids,
            num_expert_group, topk_group, renormalize, scale_factor)

    # Verify the op's fake implementation
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor
        },
        test_utils=("test_faketensor"))

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(biased_grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    topk_weights_original = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_original = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    topk_weights_compiled = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_compiled = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                           topk_weights_original, topk_ids_original)
    compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
                topk_ids_compiled)

    # Sort the results for comparison since the order might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1,
                                         indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
                                         indices_compiled)

    # Verify results match
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)
    assert torch.allclose(topk_ids_original, topk_ids_compiled)
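The "implementation vs. fake" relationship this test exercises is the standard torch.library pattern: the real function launches the AITER kernel, while the fake is a no-op because the op only mutates the pre-allocated topk_weights / topk_ids buffers, so tracing needs no output metadata. vLLM registers the pair through its direct_register_custom_op helper, as shown in the hunks below; in terms of the public torch.library API (PyTorch 2.4+), roughly equivalent wiring would look like the sketch here, with the "mylib" namespace purely illustrative:

import torch


@torch.library.custom_op("mylib::biased_grouped_topk",
                         mutates_args=("topk_weights", "topk_ids"))
def biased_grouped_topk(gating_output: torch.Tensor,
                        correction_bias: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        num_expert_group: int,
                        topk_group: int,
                        need_renorm: bool,
                        routed_scaling_factor: float = 1.0) -> None:
    # Real implementation: dispatch to the AITER kernel.
    from aiter import biased_grouped_topk as aiter_biased_grouped_topk
    aiter_biased_grouped_topk(gating_output, correction_bias, topk_weights,
                              topk_ids, num_expert_group, topk_group,
                              need_renorm, routed_scaling_factor)


@biased_grouped_topk.register_fake
def _(gating_output, correction_bias, topk_weights, topk_ids,
      num_expert_group, topk_group, need_renorm,
      routed_scaling_factor=1.0) -> None:
    # Nothing to compute: the outputs are the caller-provided buffers.
    pass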
vllm/model_executor/layers/fused_moe/layer.py

@@ -17,6 +17,8 @@ from vllm.distributed import (get_dp_group, get_ep_group,
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -28,6 +30,11 @@ if current_platform.is_cuda_alike():
     from .fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
+if is_rocm_aiter_moe_enabled():
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+        rocm_aiter_biased_group_topk as grouped_topk)
+else:
+    from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
     # the iterative moe implementation is used until the moe_pallas is fixed
     from .moe_torch_iterative import fused_moe as fused_moe_pallas
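The practical effect of the second hunk above is that the module-level name grouped_topk now resolves to the AITER wrapper whenever is_rocm_aiter_moe_enabled() returns True, and to the reference implementation otherwise, so existing call sites pick up the new kernel without changes. A small sanity-check sketch, which assumes the file being patched is vllm/model_executor/layers/fused_moe/layer.py:

# Sanity-check sketch of the import-time dispatch (module path assumed).
import importlib

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)

moe_layer = importlib.import_module(
    "vllm.model_executor.layers.fused_moe.layer")
if is_rocm_aiter_moe_enabled():
    # The "as grouped_topk" alias keeps the original function's __name__.
    assert moe_layer.grouped_topk.__name__ == "rocm_aiter_biased_group_topk"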
@@ -802,8 +809,7 @@ class FusedMoE(torch.nn.Module):
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_topk, grouped_topk)
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

@@ -216,6 +216,37 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
     pass


+def rocm_aiter_biased_grouped_topk_impl(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+
+    from aiter import biased_grouped_topk
+
+    biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids,
+                        num_expert_group, topk_group, need_renorm,
+                        routed_scaling_factor)
+
+
+def rocm_aiter_biased_grouped_topk_fake(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+    pass
+
+
 if current_platform.is_rocm():

     direct_register_custom_op(
@@ -258,6 +289,46 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )

+    direct_register_custom_op(
+        op_name="rocm_aiter_biased_grouped_topk",
+        op_func=rocm_aiter_biased_grouped_topk_impl,
+        mutates_args=["topk_weights", "topk_ids"],
+        fake_impl=rocm_aiter_biased_grouped_topk_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def rocm_aiter_biased_group_topk(
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "sigmoid",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert scoring_func == "sigmoid", (
+        "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.")
+    assert e_score_correction_bias is not None, (
+        "'e_score_correction_bias' must not be None.")
+    token = hidden_states.shape[0]
+    device = hidden_states.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+    torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+        gating_output,
+        e_score_correction_bias,
+        topk_weights,
+        topk_ids,
+        num_expert_group,
+        topk_group,
+        renormalize,
+    )
+    return topk_weights, topk_ids
+
+
 def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
                              w1: torch.Tensor,
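Putting the pieces together, the wrapper added at the end of rocm_aiter_fused_moe.py can be called directly in eager mode. A usage sketch, assuming a ROCm GPU with aiter installed; the shapes mirror the DeepSeek-V3 configuration from the test (256 experts, 8 expert groups, top-8 experts drawn from the top-4 groups), and the 7168 hidden size is illustrative since the wrapper only reads hidden_states.shape[0] and its device:

import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_biased_group_topk)

hidden_states = torch.randn(64, 7168, dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")
e_score_correction_bias = torch.randn(256, dtype=torch.bfloat16, device="cuda")

topk_weights, topk_ids = rocm_aiter_biased_group_topk(
    hidden_states,
    gating_output,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    e_score_correction_bias=e_score_correction_bias,
)
# topk_weights: (64, 8) float32, topk_ids: (64, 8) int32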