[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
TJian 2025-05-14 18:03:11 +08:00 committed by GitHub
parent 38fe728d60
commit 612c2edb4f
7 changed files with 201 additions and 112 deletions
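
For context, the AITER-backed MoE path touched in this commit is opt-in on ROCm and is gated by the VLLM_ROCM_USE_AITER environment variable (checked via is_rocm_aiter_moe_enabled() throughout the diff below, possibly together with related AITER sub-flags not shown here). A minimal usage sketch; the model name is only illustrative, and running it assumes a ROCm build of vLLM with the aiter package installed:

    import os
    os.environ["VLLM_ROCM_USE_AITER"] = "1"  # route MoE layers through the AITER CK kernels

    from vllm import LLM

    # Any MoE model (e.g. Mixtral) exercises the fused-MoE path changed here.
    llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    print(llm.generate("Hello")[0].outputs[0].text)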

View File

@@ -224,9 +224,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
+    # clear the cache before every test
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+        is_rocm_aiter_moe_enabled)
+    is_rocm_aiter_moe_enabled.cache_clear()
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+        if dtype == torch.float32:
+            pytest.skip("AITER ROCm test skip for float32")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
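
A note on the new cache_clear() call: is_rocm_aiter_moe_enabled() caches the result of the VLLM_ROCM_USE_AITER environment check (hence the .cache_clear() attribute), so a value memoized by an earlier parametrized run would otherwise leak into this one. A minimal sketch of that test pattern, using a hypothetical cached helper in place of the real function:

    import functools
    import os

    @functools.cache  # the env var is read once; later calls reuse the cached bool
    def flag_enabled() -> bool:
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    def test_flag(monkeypatch):
        flag_enabled.cache_clear()  # drop whatever a previous test cached
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        assert flag_enabled()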

View File

@@ -8,10 +8,8 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    dispatch_fused_experts_func, dispatch_topk_func,
-    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
-    vllm_topk_softmax)
+from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
+                                                            vllm_topk_softmax)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.layernorm import (
@@ -142,24 +140,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     assert topk_func == vllm_topk_softmax


-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("inplace", [True, False])
-def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
-                                monkeypatch):
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    is_rocm_aiter_moe_enabled.cache_clear()
-    fused_experts_func = dispatch_fused_experts_func(inplace)
-    if current_platform.is_rocm() and int(use_rocm_aiter):
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-        assert fused_experts_func == rocm_aiter_fused_experts
-    elif inplace:
-        assert fused_experts_func == torch_vllm_inplace_fused_experts
-    else:
-        assert fused_experts_func == torch_vllm_outplace_fused_experts
-
-
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])

View File

@@ -1100,9 +1100,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:


 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
-    if is_rocm_aiter_moe_enabled():
-        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
-        return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
     return torch_vllm_outplace_fused_experts

View File

@@ -84,6 +84,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

+    def __init__(self):
+        super().__init__()
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+        if self.rocm_aiter_moe_enabled:
+            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
+        else:
+            self.rocm_aiter_fused_experts = None  # type: ignore
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -126,11 +136,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled, shuffle_weights)
-        if is_rocm_aiter_moe_enabled():
-            # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffle_weights)
+
+        if self.rocm_aiter_moe_enabled:
+            # use 2stage ck moe layout
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=(32, 32))
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
@@ -211,6 +223,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
         return fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,

View File

@@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
-        topk_weight: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         fc1_scale: Optional[torch.Tensor] = None,
         fc2_scale: Optional[torch.Tensor] = None,
@@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
     return asm_moe_tkw1(hidden_states,
                         w1,
                         w2,
-                        topk_weight,
+                        topk_weights,
                         topk_ids,
                         fc1_scale=fc1_scale,
                         fc2_scale=fc2_scale,
@@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
-        topk_weight: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         fc1_scale: Optional[torch.Tensor] = None,
         fc2_scale: Optional[torch.Tensor] = None,
@@ -69,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake(
     return torch.empty_like(hidden_states)


-def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    from aiter import ck_moe
-    return ck_moe(hidden_states=hidden_states,
-                  w1=w1,
-                  w2=w2,
-                  topk_weights=topk_weights,
-                  topk_ids=topk_ids)
-
-
-def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
 def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
@@ -152,7 +135,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
 def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,
@@ -175,7 +158,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
     return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                        w1=w1,
                                        w2=w2,
-                                       topk_weight=topk_weight,
+                                       topk_weight=topk_weights,
                                        topk_ids=topk_ids,
                                        fc1_scale=fc1_scale,
                                        fc2_scale=fc2_scale,
@@ -188,7 +171,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
 def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,
@@ -199,6 +182,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
     return torch.empty_like(hidden_states)


+def rocm_aiter_ck_moe_2stages_impl(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        fc1_scale: Optional[torch.Tensor] = None,
+        fc2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_size: Optional[list[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    return ck_moe_2stages(a1=hidden_states,
+                          w1=w1,
+                          w2=w2,
+                          topk_weight=topk_weights,
+                          topk_ids=topk_ids,
+                          fc1_scale=fc1_scale,
+                          fc2_scale=fc2_scale,
+                          a1_scale=a1_scale,
+                          a2_scale=a2_scale,
+                          block_size=block_size,
+                          expert_mask=expert_mask)
+
+
+def rocm_aiter_ck_moe_2stages_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        fc1_scale: Optional[torch.Tensor] = None,
+        fc2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_size: Optional[list[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
                                  topk_indices: torch.Tensor,
                                  token_expert_indices: torch.Tensor,
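
The two wrappers above are registered further down as the custom op rocm_aiter_ck_moe_2stages (the fake implementation only supplies output shapes for tracing). A minimal w16a16 invocation sketch, mirroring the fallback call later in this file; the tensor names are placeholders and the call assumes a ROCm build with the aiter package installed:

    # hidden_states: [num_tokens, hidden_size] activations (e.g. bf16);
    # w13_weight / w2_weight pre-shuffled with shuffle_weights(..., layout=(32, 32));
    # topk_ids cast to int32 and topk_weights to float32, as the w16a16 fallback does.
    out = torch.ops.vllm.rocm_aiter_ck_moe_2stages(hidden_states=hidden_states,
                                                   w1=w13_weight,
                                                   w2=w2_weight,
                                                   topk_weights=topk_weights,
                                                   topk_ids=topk_ids)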
@@ -258,14 +284,6 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )

-    direct_register_custom_op(
-        op_name="rocm_aiter_ck_moe",
-        op_func=rocm_aiter_ck_moe_impl,
-        mutates_args=[],
-        fake_impl=rocm_aiter_ck_moe_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
     direct_register_custom_op(
         op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
         op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
@@ -282,6 +300,14 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )

+    direct_register_custom_op(
+        op_name="rocm_aiter_ck_moe_2stages",
+        op_func=rocm_aiter_ck_moe_2stages_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_ck_moe_2stages_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_topk_softmax",
         op_func=rocm_aiter_topk_softmax_impl,
@@ -331,29 +357,21 @@ def rocm_aiter_biased_group_topk(
     return topk_weights, topk_ids


-def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
-                             w1: torch.Tensor,
-                             w2: torch.Tensor,
-                             topk_weights: torch.Tensor,
-                             topk_ids: torch.Tensor,
-                             inplace: bool = False,
-                             activation: str = "silu",
-                             apply_router_weight_on_input: bool = False,
-                             use_fp8_w8a8: bool = False,
-                             use_int8_w8a8: bool = False,
-                             use_int8_w8a16: bool = False,
-                             use_int4_w4a16: bool = False,
-                             per_channel_quant: bool = False,
-                             global_num_experts: int = -1,
-                             expert_map: Optional[torch.Tensor] = None,
-                             w1_scale: Optional[torch.Tensor] = None,
-                             w2_scale: Optional[torch.Tensor] = None,
-                             w1_zp: Optional[torch.Tensor] = None,
-                             w2_zp: Optional[torch.Tensor] = None,
-                             a1_scale: Optional[torch.Tensor] = None,
-                             a2_scale: Optional[torch.Tensor] = None,
-                             block_shape: Optional[list[int]] = None,
-                             allow_deep_gemm: bool = False) -> torch.Tensor:
+def rocm_aiter_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        per_channel_quant: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:

     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
@@ -376,8 +394,8 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
         a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])

         return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
-            topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
-            w2, w1_scale, w2_scale, a1_scale, block_shape, None)
+            topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
+            w1_scale, w2_scale, a1_scale, block_shape, None)

     # w8a8 per-channel quantization
     elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
@@ -402,17 +420,36 @@
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=expert_map,
+            expert_mask=None,
             activation_str=activation)

     # w8a8 per-tensor activation per-tensor weight
     elif use_fp8_w8a8:
         assert not apply_router_weight_on_input, (
             "apply_router_weight_on_input is not supported for fp8_w8a8")
+
+        # - faster static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        if a1_scale is not None and a2_scale is not None:
+            return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                fc1_scale=w1_scale,
+                fc2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale)
+
+        # - fallback static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        # - dynamic per-tensor activation static per-tensor-weight
+        # fp8 quantization w8a8
         return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
                                                  w1=w1,
                                                  w2=w2,
-                                                 topk_weight=topk_weights,
+                                                 topk_weights=topk_weights,
                                                  topk_ids=topk_ids,
                                                  fc1_scale=w1_scale,
                                                  fc2_scale=w2_scale,
@@ -432,12 +469,12 @@
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

-    # w16a16 fallback to rocm_aiter_ck_moe w16a16
-    return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
-                                            w1=w1,
-                                            w2=w2,
-                                            topk_weights=topk_weights,
-                                            topk_ids=topk_ids)
+    return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids)


 def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
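
Taken together, the kernel selection inside rocm_aiter_fused_experts now follows roughly the order sketched below. This is a condensed restatement of the branches above, not the literal control flow; the helper name and the block_quant / static_act_scales booleans are illustrative stand-ins for the block_shape-is-not-None and a1_scale/a2_scale-are-not-None checks in the real code:

    def select_aiter_moe_op(use_fp8_w8a8: bool, per_channel_quant: bool,
                            apply_router_weight_on_input: bool,
                            block_quant: bool, static_act_scales: bool) -> str:
        # mirrors the dispatch order implemented above
        if use_fp8_w8a8 and block_quant:
            return "rocm_aiter_fmoe_fp8_blockscale_g1u1"  # block-scaled fp8
        if use_fp8_w8a8 and per_channel_quant and apply_router_weight_on_input:
            return "rocm_aiter_asm_moe_tkw1"  # per-channel fp8, router weights folded into the input
        if use_fp8_w8a8 and static_act_scales:
            return "rocm_aiter_ck_moe_2stages"  # faster static per-tensor fp8 path (new here)
        if use_fp8_w8a8:
            return "rocm_aiter_asm_moe"  # remaining per-tensor fp8 cases
        return "rocm_aiter_ck_moe_2stages"  # w16a16 fallback (replaces rocm_aiter_ck_moe)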
@@ -451,7 +488,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices


-def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.
@@ -463,7 +501,8 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
         A Tuple of shuffled tensors.
     """
     from aiter.ops.shuffle import shuffle_weight
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


 def expand_weights(*tensors: torch.Tensor,
@@ -485,4 +524,4 @@ def expand_weights(*tensors: torch.Tensor,
     return tuple(
         tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
         for tensor, dim in zip(tensors, expansion_dims))
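
Because layout is now a required keyword-only argument to shuffle_weights, each caller in this change passes the tile layout its target kernel expects, e.g. (32, 32) for the unquantized CK 2-stage path and (16, 16) or (32, 32) for the fp8 paths in the quantization methods below. A minimal usage sketch; the weight tensors are placeholders for a layer's expert weights:

    shuffled_w13, shuffled_w2 = shuffle_weights(w13_weight.data,
                                                w2_weight.data,
                                                layout=(32, 32))  # layout chosen per target kernel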

View File

@@ -125,6 +125,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False

+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -276,24 +280,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled)
-
         # Property to determine if AITER is used
-        if is_rocm_aiter_moe_enabled():
+        if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                 rocm_aiter_fused_experts, shuffle_weights)

             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=(16, 16))

             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                  requires_grad=False)

-            self.fused_experts_func = rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
@@ -335,6 +337,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts_func(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy ==
+                QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale)
+
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")

View File

@@ -591,6 +591,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -616,10 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))

                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@ -663,7 +667,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False) requires_grad=False)
if is_rocm_aiter_moe_enabled(): if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
w13_scales, w2_scales = expand_weights( w13_scales, w2_scales = expand_weights(
layer.w13_weight_scale.data, layer.w13_weight_scale.data,
@ -676,8 +680,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale = torch.nn.Parameter( layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False) w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
layer.w13_weight, layer.w2_weight) layer.w2_weight,
layout=(16, 16))
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False) requires_grad=False)
@@ -748,7 +753,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                         dq_weight, max_w13_scales[expert_id])
                     start += shard_size

-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 expansion_dims = [
                     layer.w13_weight.shape[1], layer.w2_weight.shape[1]
@@ -760,8 +765,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)

-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight, layer.w2_weight)
+                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                            layer.w2_weight,
+                                                            layout=(32, 32))

                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -796,6 +802,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -810,6 +818,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             e_score_correction_bias=e_score_correction_bias,
         )

+        if self.rocm_aiter_moe_enabled:
+            return rocm_aiter_fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                use_fp8_w8a8=True,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size)
+
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")