[FEAT][ROCm] Add AITER grouped topk for DeepSeekV2 (#18825)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent c55d804672
commit 0f5e0d567e
@@ -35,6 +35,15 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
     assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
 
 
+def test_rocm_aiter_grouped_topk_custom_op_registration():
+    """Test that the custom op is correctly registered."""
+    # Check if the op exists in torch.ops.vllm
+    assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
+
+    # Check if the op is callable
+    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
+
+
 def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
     """Test that the op can be used with torch.compile."""
     # Create test tensors
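Note (not part of the diff): the registration checks above assume the custom ops have already been installed under torch.ops.vllm. A minimal sketch of that precondition, assuming a ROCm build of vLLM where importing the module below registers the ops at import time (the later hunks show direct_register_custom_op running under current_platform.is_rocm()):

    import torch

    # Importing the module registers the custom ops as a side effect (ROCm only).
    from vllm.model_executor.layers.fused_moe import rocm_aiter_fused_moe  # noqa: F401

    assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
    assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")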
@@ -120,3 +129,87 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
                           rtol=1e-2,
                           atol=1e-2)
     assert torch.allclose(topk_ids_original, topk_ids_compiled)
+
+
+def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
+    """Test that the op can be used with torch.compile."""
+    # Create test tensors
+    token = 64
+    expert = 256
+    num_expert_group = 8
+    topk = 8
+    topk_group = 4
+    renormalize = True
+    scoring_func = "softmax"
+    scale_factor = 1.0
+
+    gating_output = torch.randn((token, expert),
+                                dtype=torch.bfloat16,
+                                device="cuda")
+
+    device = gating_output.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+
+    # Define a function that uses the op
+    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
+        return torch.ops.vllm.rocm_aiter_grouped_topk(
+            gating_output, topk_weights, topk_ids, num_expert_group,
+            topk_group, renormalize, scoring_func, scale_factor)
+
+    # Verify the op's fake implementation
+    torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
+                          (gating_output, topk_weights, topk_ids),
+                          kwargs={
+                              "num_expert_group": num_expert_group,
+                              "topk_group": topk_group,
+                              "need_renorm": renormalize,
+                              "scoring_func": scoring_func,
+                              "routed_scaling_factor": scale_factor
+                          },
+                          test_utils=("test_faketensor"))
+
+    # Compile the function with appropriate settings
+    compiled_fn = torch.compile(grouped_topk_fn,
+                                fullgraph=True,
+                                backend="inductor",
+                                mode="reduce-overhead",
+                                dynamic=False)
+
+    topk_weights_original = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_original = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    topk_weights_compiled = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_compiled = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
+    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
+                    scoring_func)
+    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
+                scoring_func)
+
+    # Sort the results for comparison since the order might not be deterministic
+    topk_ids_original, indices_original = torch.sort(topk_ids_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1,
+                                         indices_original)
+
+    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
+                                         indices_compiled)
+
+    # Verify results match
+    assert torch.allclose(topk_weights_original,
+                          topk_weights_compiled,
+                          rtol=1e-2,
+                          atol=1e-2)
+    assert torch.allclose(topk_ids_original, topk_ids_compiled)
@@ -45,7 +45,7 @@ else:
     FusedMoEPrepareAndFinalize = None  # type: ignore
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-        rocm_aiter_biased_group_topk as grouped_topk)
+        rocm_aiter_grouped_topk as grouped_topk)
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
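Note (not part of the diff): the AITER alias above is only taken when is_rocm_aiter_moe_enabled() returns True. A rough sketch of opting in, assuming the usual vLLM environment flags VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MOE are what gate this helper, and that the helper lives in the same rocm_aiter_fused_moe module (neither detail is shown in this diff):

    import os

    # Assumed flag names; set before vLLM imports its MoE layers so the
    # AITER-backed grouped_topk alias is picked up instead of the reference one.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled)

    assert is_rocm_aiter_moe_enabled()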
@@ -140,6 +140,36 @@ def rocm_aiter_biased_grouped_topk_fake(
     pass
 
 
+def rocm_aiter_grouped_topk_impl(
+        gating_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+
+    from aiter import grouped_topk
+
+    grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group,
+                 topk_group, need_renorm, scoring_func, routed_scaling_factor)
+
+
+def rocm_aiter_grouped_topk_fake(
+        gating_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+    pass
+
+
 def rocm_aiter_fused_moe_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
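Note (not part of the diff): both the impl and its fake are out-variant ops — the caller preallocates topk_weights and topk_ids, the AITER kernel writes into them and returns None, so the fake can be a no-op with a matching signature (the mutation is declared via mutates_args in the registration below). A short sketch of calling the registered op directly, with shapes, dtypes, and argument values taken from the test hunk above (a sketch only; requires a ROCm build with AITER available):

    import torch

    token, expert, topk = 64, 256, 8
    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device="cuda")
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device="cuda")

    # Writes the routing results into topk_weights / topk_ids in place; returns None.
    # Positional args: num_expert_group=8, topk_group=4, need_renorm=True,
    # scoring_func="softmax", routed_scaling_factor=1.0.
    torch.ops.vllm.rocm_aiter_grouped_topk(gating_output, topk_weights, topk_ids,
                                           8, 4, True, "softmax", 1.0)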
@@ -218,27 +248,33 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_grouped_topk",
+        op_func=rocm_aiter_grouped_topk_impl,
+        mutates_args=["topk_weights", "topk_ids"],
+        fake_impl=rocm_aiter_grouped_topk_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 
-def rocm_aiter_biased_group_topk(
+def rocm_aiter_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    scoring_func: str = "sigmoid",
+    scoring_func: str = "softmax",
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert scoring_func == "sigmoid", (
-        "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.")
-    assert e_score_correction_bias is not None, (
-        "'e_score_correction_bias' must not be None.")
     token = hidden_states.shape[0]
     device = hidden_states.device
     topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
     topk_weights = torch.empty((token, topk),
                                dtype=torch.float32,
                                device=device)
+
+    if e_score_correction_bias is not None:
         torch.ops.vllm.rocm_aiter_biased_grouped_topk(
             gating_output,
             e_score_correction_bias,
@@ -248,6 +284,18 @@ def rocm_aiter_biased_group_topk(
             topk_group,
             renormalize,
         )
+    else:
+        assert (scoring_func == "softmax" or scoring_func == "sigmoid")
+        torch.ops.vllm.rocm_aiter_grouped_topk(
+            gating_output,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            scoring_func,
+        )
+
     return topk_weights, topk_ids
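Note (not part of the diff): with this change the renamed public wrapper rocm_aiter_grouped_topk dispatches between the two registered ops based on whether a correction bias is supplied. An illustrative sketch of both paths, using the signature from the hunks above (shapes and dtypes are made up for illustration; requires ROCm + AITER):

    import torch
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        rocm_aiter_grouped_topk)

    hidden_states = torch.randn((64, 5120), dtype=torch.bfloat16, device="cuda")
    gating_output = torch.randn((64, 256), dtype=torch.bfloat16, device="cuda")

    # No bias -> the new torch.ops.vllm.rocm_aiter_grouped_topk path
    # ("softmax" routing as used by DeepSeekV2; "sigmoid" is also accepted).
    topk_weights, topk_ids = rocm_aiter_grouped_topk(
        hidden_states, gating_output, topk=8, renormalize=True,
        num_expert_group=8, topk_group=4, scoring_func="softmax")

    # With a correction bias -> the pre-existing biased op is used instead.
    bias = torch.zeros(256, dtype=torch.bfloat16, device="cuda")
    topk_weights, topk_ids = rocm_aiter_grouped_topk(
        hidden_states, gating_output, topk=8, renormalize=True,
        num_expert_group=8, topk_group=4, scoring_func="sigmoid",
        e_score_correction_bias=bias)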