From 612c2edb4f940b26550ca64576995957e4118a12 Mon Sep 17 00:00:00 2001 From: TJian Date: Wed, 14 May 2025 18:03:11 +0800 Subject: [PATCH] [FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110) Signed-off-by: tjtanaa Co-authored-by: Gregory Shtrasberg --- tests/kernels/moe/test_moe.py | 7 + .../model_executor/test_enabled_custom_ops.py | 24 +-- .../layers/fused_moe/fused_moe.py | 3 - vllm/model_executor/layers/fused_moe/layer.py | 32 +++- .../layers/fused_moe/rocm_aiter_fused_moe.py | 173 +++++++++++------- .../compressed_tensors_moe.py | 32 +++- .../model_executor/layers/quantization/fp8.py | 42 ++++- 7 files changed, 201 insertions(+), 112 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index c1d0940f26cb..96b090136e3c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -224,9 +224,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" + # clear the cache before every test + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + if dtype == torch.float32: + pytest.skip("AITER ROCm test skip for float32") + # Instantiate our and huggingface's MoE blocks config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 93453ddb657c..e957db5b3f16 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -8,10 +8,8 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - dispatch_fused_experts_func, dispatch_topk_func, - torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, - vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, + vllm_topk_softmax) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( @@ -142,24 +140,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): assert topk_func == vllm_topk_softmax -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - is_rocm_aiter_moe_enabled.cache_clear() - fused_experts_func = dispatch_fused_experts_func(inplace) - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts - elif inplace: - assert fused_experts_func == torch_vllm_inplace_fused_experts - else: - assert fused_experts_func == torch_vllm_outplace_fused_experts - - @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8c28cedbcd77..7bf4243305ac 100644 
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1100,9 +1100,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_fused_experts - return rocm_aiter_fused_experts if inplace: return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f74e38bde6b8..14f360e3bbf3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -84,6 +84,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self): + super().__init__() + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + if self.rocm_aiter_moe_enabled: + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts + else: + self.rocm_aiter_fused_experts = None # type: ignore + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -126,11 +136,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) - if is_rocm_aiter_moe_enabled(): - # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + shuffle_weights) + + if self.rocm_aiter_moe_enabled: + # use 2stage ck moe layout + shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, + layer.w2_weight.data, + layout=(32, 32)) layer.w13_weight.data = shuffled_w13 layer.w2_weight.data = shuffled_w2 @@ -211,6 +223,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + return fused_experts( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 04155ab69afd..a92081862bfa 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None, @@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl( return asm_moe_tkw1(hidden_states, w1, w2, - topk_weight, + topk_weights, topk_ids, fc1_scale=fc1_scale, fc2_scale=fc2_scale, @@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None, @@ -69,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake( return torch.empty_like(hidden_states) -def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: - from aiter import ck_moe - return ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) - - -def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: - return torch.empty_like(hidden_states) - - def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( topk_ids: torch.Tensor, topk_weights: torch.Tensor, @@ -152,7 +135,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = None, @@ -175,7 +158,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, w1=w1, w2=w2, - topk_weight=topk_weight, + topk_weight=topk_weights, topk_ids=topk_ids, fc1_scale=fc1_scale, fc2_scale=fc2_scale, @@ -188,7 +171,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weight: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, fc1_scale: Optional[torch.Tensor] = None, fc2_scale: Optional[torch.Tensor] = 
None, @@ -199,6 +182,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, return torch.empty_like(hidden_states) +def rocm_aiter_ck_moe_2stages_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_size: Optional[list[int]] = None, + expert_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from aiter.fused_moe_bf16_asm import ck_moe_2stages + return ck_moe_2stages(a1=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weights, + topk_ids=topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_size=block_size, + expert_mask=expert_mask) + + +def rocm_aiter_ck_moe_2stages_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_size: Optional[list[int]] = None, + expert_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, @@ -258,14 +284,6 @@ if current_platform.is_rocm(): dispatch_key=current_platform.dispatch_key, ) - direct_register_custom_op( - op_name="rocm_aiter_ck_moe", - op_func=rocm_aiter_ck_moe_impl, - mutates_args=[], - fake_impl=rocm_aiter_ck_moe_fake, - dispatch_key=current_platform.dispatch_key, - ) - direct_register_custom_op( op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, @@ -282,6 +300,14 @@ if current_platform.is_rocm(): dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_ck_moe_2stages", + op_func=rocm_aiter_ck_moe_2stages_impl, + mutates_args=[], + fake_impl=rocm_aiter_ck_moe_2stages_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_topk_softmax", op_func=rocm_aiter_topk_softmax_impl, @@ -331,29 +357,21 @@ def rocm_aiter_biased_group_topk( return topk_weights, topk_ids -def rocm_aiter_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: +def rocm_aiter_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = 
False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) @@ -376,8 +394,8 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1]) return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( - topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1, - w2, w1_scale, w2_scale, a1_scale, block_shape, None) + topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2, + w1_scale, w2_scale, a1_scale, block_shape, None) # w8a8 per-channel quantization elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: @@ -402,17 +420,36 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, fc2_smooth_scale=None, a16=False, per_tensor_quant_scale=None, - expert_mask=expert_map, + expert_mask=None, activation_str=activation) # w8a8 per-tensor activation per-tensor weight elif use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is not supported for fp8_w8a8") + + # - faster static per-tensor-activation static per-tensor-weight + # fp8 quantization w8a8 + if a1_scale is not None and a2_scale is not None: + return torch.ops.vllm.rocm_aiter_ck_moe_2stages( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) + + # - fallback static per-tensor-activation static per-tensor-weight + # fp8 quantization w8a8 + # - dynamic per-tensor activation static per-tensor-weight + # fp8 quantization w8a8 return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, w1=w1, w2=w2, - topk_weight=topk_weights, + topk_weights=topk_weights, topk_ids=topk_ids, fc1_scale=w1_scale, fc2_scale=w2_scale, @@ -432,12 +469,12 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, topk_ids = topk_ids.to(torch.int32) topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) - # w16a16 fallback to rocm_aiter_ck_moe w16a16 - return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + return torch.ops.vllm.rocm_aiter_ck_moe_2stages( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, @@ -451,7 +488,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: +def shuffle_weights(*tensors: torch.Tensor, + layout: tuple[int, int]) -> tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. @@ -463,7 +501,8 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: A Tuple of shuffled tensors. 
""" from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor) for tensor in tensors) + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) def expand_weights(*tensors: torch.Tensor, @@ -485,4 +524,4 @@ def expand_weights(*tensors: torch.Tensor, return tuple( tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1)) - for tensor, dim in zip(tensors, expansion_dims)) + for tensor, dim in zip(tensors, expansion_dims)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a74f1f7233af..fa0067c44802 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -125,6 +125,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -276,24 +280,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) - # Property to determine if AITER is used - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 rocm_aiter_fused_experts, shuffle_weights) # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, + layer.w2_weight.data, + layout=(16, 16)) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.fused_experts_func = rocm_aiter_fused_experts + self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts else: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts @@ -335,6 +337,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy == + QuantizationStrategy.CHANNEL, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) if self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 589ca7bed329..cfd398c07fb9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -591,6 +591,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + # TODO (rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" @@ -616,10 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, requires_grad=False) - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, + layer.w2_weight.data, + layout=(16, 16)) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) @@ -663,7 +667,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. w13_scales, w2_scales = expand_weights( layer.w13_weight_scale.data, @@ -676,8 +680,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale = torch.nn.Parameter( w2_scales.contiguous(), requires_grad=False) - shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight, + layer.w2_weight, + layout=(16, 16)) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) @@ -748,7 +753,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): dq_weight, max_w13_scales[expert_id]) start += shard_size - if is_rocm_aiter_moe_enabled(): + if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
expansion_dims = [ layer.w13_weight.shape[1], layer.w2_weight.shape[1] @@ -760,8 +765,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale = torch.nn.Parameter( w2_scales.contiguous(), requires_grad=False) - shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight, + layer.w2_weight, + layout=(32, 32)) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) @@ -796,6 +802,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -810,6 +818,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, ) + if self.rocm_aiter_moe_enabled: + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + use_fp8_w8a8=True, + apply_router_weight_on_input=apply_router_weight_on_input, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size) + if self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.")
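
A minimal sketch (not part of the patch) of the dispatch order that rocm_aiter_fused_experts follows after this change, for reviewers tracing which registered torch.ops.vllm op serves each quantization mode. The helper name select_aiter_moe_op and its string return values are illustrative assumptions only; the real function calls the ops named below and returns their output tensors.

def select_aiter_moe_op(use_fp8_w8a8: bool,
                        per_channel_quant: bool,
                        apply_router_weight_on_input: bool,
                        block_shape,
                        a1_scale,
                        a2_scale) -> str:
    # Illustrative decision tree mirroring the branches in
    # rocm_aiter_fused_experts as modified by this patch.
    if use_fp8_w8a8 and block_shape is not None:
        # w8a8 block-scaled weights and per-token-group activations
        return "rocm_aiter_fmoe_fp8_blockscale_g1u1"
    if use_fp8_w8a8 and per_channel_quant and apply_router_weight_on_input:
        # w8a8 per-channel quantization with router weights applied on input
        return "rocm_aiter_asm_moe_tkw1"
    if use_fp8_w8a8:
        if a1_scale is not None and a2_scale is not None:
            # static per-tensor-activation, per-tensor-weight scales:
            # the new, faster CK 2-stage kernel
            return "rocm_aiter_ck_moe_2stages"
        # dynamic per-tensor activation scales fall back to asm_moe
        return "rocm_aiter_asm_moe"
    # unquantized w16a16 now also routes to the CK 2-stage kernel,
    # replacing the removed rocm_aiter_ck_moe op
    return "rocm_aiter_ck_moe_2stages"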
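
Assuming a ROCm build of vLLM with the aiter package installed, the short pytest sketch below shows the toggle-and-cache-clear pattern that this patch adds to the tests: is_rocm_aiter_moe_enabled is cached, so the cache must be cleared whenever VLLM_ROCM_USE_AITER changes. The test name test_aiter_moe_toggle is hypothetical and is not part of this patch.

import pytest

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_aiter_moe_toggle(use_rocm_aiter: str, monkeypatch):
    # The enable check is cached; clear it before every test so a value
    # cached by an earlier test does not leak into this one.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    is_rocm_aiter_moe_enabled.cache_clear()

    # After this patch, dispatch is decided per layer method via
    # self.rocm_aiter_moe_enabled rather than inside
    # dispatch_fused_experts_func.
    enabled = is_rocm_aiter_moe_enabled()
    assert isinstance(enabled, bool)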