[FEAT] [ROCm]: AITER Fused MOE V1 Support (#16752)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
vllmellm 2025-04-25 11:06:50 +08:00 committed by GitHub
parent 0d6e187e88
commit eef364723c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 302 additions and 130 deletions

View File

@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func, dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax) vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str):
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func() topk_func = dispatch_topk_func()
is_rocm_aiter_moe_enabled.cache_clear()
if current_platform.is_rocm() and int(use_rocm_aiter): if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax) rocm_aiter_topk_softmax)
assert topk_func == rocm_aiter_topk_softmax assert topk_func == rocm_aiter_topk_softmax
else: else:
assert topk_func == vllm_topk_softmax assert topk_func == vllm_topk_softmax
@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch): monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
is_rocm_aiter_moe_enabled.cache_clear()
fused_experts_func = dispatch_fused_experts_func(inplace) fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter): if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts) rocm_aiter_fused_experts)
assert fused_experts_func == rocm_aiter_fused_experts assert fused_experts_func == rocm_aiter_fused_experts
elif inplace: elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts assert fused_experts_func == torch_vllm_inplace_fused_experts

View File

@ -1,31 +1,35 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from functools import cache
from typing import List, Optional, Tuple
import torch import torch
import vllm.envs as envs from vllm import envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@cache
def is_rocm_aiter_moe_enabled() -> bool: def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \ return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \ and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER
def rocm_aiter_asm_moe_tkw1(hidden_states, def rocm_aiter_asm_moe_tkw1_impl(
w1, hidden_states: torch.Tensor,
w2, w1: torch.Tensor,
topk_weight, w2: torch.Tensor,
topk_ids, topk_weight: torch.Tensor,
fc1_scale=None, topk_ids: torch.Tensor,
fc2_scale=None, fc1_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale=None, fc2_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale=None, fc1_smooth_scale: Optional[torch.Tensor] = None,
a16=False, fc2_smooth_scale: Optional[torch.Tensor] = None,
per_tensor_quant_scale=None, a16: bool = False,
expert_mask=None, per_tensor_quant_scale: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> None: expert_mask: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> torch.Tensor:
from aiter import ActivationType from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1 from aiter.fused_moe_bf16_asm import asm_moe_tkw1
@ -48,34 +52,236 @@ def rocm_aiter_asm_moe_tkw1(hidden_states,
activation=activation) activation=activation)
def rocm_aiter_fused_experts( def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, fc1_scale: Optional[torch.Tensor] = None,
activation: str = "silu", fc2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, fc1_smooth_scale: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False, fc2_smooth_scale: Optional[torch.Tensor] = None,
use_int8_w8a8: bool = False, a16: bool = False,
use_int8_w8a16: bool = False, per_tensor_quant_scale: Optional[torch.Tensor] = None,
use_int4_w4a16: bool = False, expert_mask: Optional[torch.Tensor] = None,
per_channel_quant: bool = False, activation_str: str = "silu") -> torch.Tensor:
global_num_experts: int = -1, return torch.empty_like(hidden_states)
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
) -> torch.Tensor:
import aiter as rocm_aiter
def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
from aiter import ck_moe
return ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
(
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
out_asm,
) = moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
hidden_states_dtype,
expert_mask=expert_mask)
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
sorted_weight_buf, sorted_expert_ids,
num_valid_ids, topk, w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(), *block_shape,
smooth_scale)
return out_asm
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=torch.bf16)
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
from aiter import ActivationType
assert activation in ["silu", "gelu"], "The given activation:" \
f" {activation}" \
" is not supported in" \
" AITER."
if activation == "silu":
aiter_activation = ActivationType.Silu
else:
aiter_activation = ActivationType.Gelu
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weight,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
activation=aiter_activation)
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> None:
from aiter import topk_softmax
topk_softmax(topk_weights, topk_indices, token_expert_indices,
gating_output, renormalize)
def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> None:
pass
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_ck_moe",
op_func=rocm_aiter_ck_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
mutates_args=[],
fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe",
op_func=rocm_aiter_asm_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
@ -84,60 +290,24 @@ def rocm_aiter_fused_experts(
topk_weights = topk_weights.to(torch.float32) topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
if (block_shape is not None) and use_fp8_w8a8: # w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for block scaled moe" "apply_router_weight_on_input is not supported for block scaled moe"
) )
assert w1_scale is not None assert w1_scale is not None
assert w2_scale is not None assert w2_scale is not None
local_E = E = w1.shape[0]
if expert_map is not None:
E = expert_map.numel()
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
dtype = hidden_states.dtype
# The default block sizes are 128 in AITER. # The default block sizes are 128 in AITER.
if block_shape is None: block_shape = [128, 128] if block_shape is None else block_shape
block_shape = [128, 128]
scale_blk_k = block_shape[1] a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
( return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
sorted_token_ids, topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
sorted_weight_buf, w2, w1_scale, w2_scale, a1_scale, block_shape, None)
sorted_expert_ids,
num_valid_ids,
out_asm,
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
dtype,
expert_mask=expert_map)
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
rocm_aiter.fmoe_fp8_blockscale_g1u1(
out_asm,
a1,
w1,
w2,
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
topk,
w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(),
block_shape[0],
block_shape[1],
None,
)
return out_asm
# w8a8 per-channel quantization
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer # This applies topk_weights on the GEMM output of the first FC layer
@ -148,34 +318,36 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when" "Only support topk=1 when"
" `apply_router_weight_on_input` is True") " `apply_router_weight_on_input` is True")
return rocm_aiter_asm_moe_tkw1(hidden_states, return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
w1, hidden_states,
w2, w1,
topk_weights, w2,
topk_ids, topk_weights,
fc1_scale=w1_scale, topk_ids,
fc2_scale=w2_scale, fc1_scale=w1_scale,
fc1_smooth_scale=None, fc2_scale=w2_scale,
fc2_smooth_scale=None, fc1_smooth_scale=None,
a16=False, fc2_smooth_scale=None,
per_tensor_quant_scale=None, a16=False,
expert_mask=expert_map, per_tensor_quant_scale=None,
activation_str=activation) expert_mask=expert_map,
activation_str=activation)
# w8a8 per-tensor activation per-tensor weight
elif use_fp8_w8a8: elif use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8") "apply_router_weight_on_input is not supported for fp8_w8a8")
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weight=topk_weights, topk_weight=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
fc1_scale=w1_scale, fc1_scale=w1_scale,
fc2_scale=w2_scale, fc2_scale=w2_scale,
fc1_smooth_scale=None, fc1_smooth_scale=None,
fc2_smooth_scale=None, fc2_smooth_scale=None,
a16=False) a16=False,
activation=activation)
if apply_router_weight_on_input: if apply_router_weight_on_input:
assert (topk_weights.dim() == 2 assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)" ), "`topk_weights` should be in shape (num_tokens, topk)"
@ -188,26 +360,26 @@ def rocm_aiter_fused_experts(
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
return rocm_aiter.ck_moe(hidden_states=hidden_states, # w16a16 fallback to rocm_aiter_ck_moe w16a16
w1=w1, return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
w2=w2, w1=w1,
topk_weights=topk_weights, w2=w2,
topk_ids=topk_ids) topk_weights=topk_weights,
topk_ids=topk_ids)
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> Tuple[torch.Tensor, ...]:
import aiter as rocm_aiter torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, token_expert_indices, gating_output,
gating_output, renormalize) renormalize)
return topk_weights, topk_indices return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
""" """
Applies shuffle_weight function from AITER to each Applies shuffle_weight function from AITER to each
input tensor and returns them. input tensor and returns them.
@ -216,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
Returns: Returns:
A tuple of shuffled tensors. A Tuple of shuffled tensors.
""" """
from aiter.ops.shuffle import shuffle_weight from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors) return tuple(shuffle_weight(tensor) for tensor in tensors)
def expand_weights(*tensors: torch.Tensor, def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
""" """
Expands the dimensions of input tensors. Expands the dimensions of input tensors.
@ -234,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor,
corresponding to each tensor. corresponding to each tensor.
Returns: Returns:
A tuple of tensors with expanded dimensions. A Tuple of tensors with expanded dimensions.
""" """
assert len(tensors) == len(expansion_dims), \ assert len(tensors) == len(expansion_dims), \

View File

@ -304,9 +304,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return self.fused_experts_func( return self.fused_experts_func(
x, hidden_states=x,
layer.w13_weight, w1=layer.w13_weight,
layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,