From 02cabff207ca68094a73ba21296c82cdbcb1d1a5 Mon Sep 17 00:00:00 2001
From: TJian
Date: Tue, 1 Jul 2025 09:48:30 -0700
Subject: [PATCH] [V1] [ROCm] Enable EP with AITER Fused MoE (#20270)

Signed-off-by: tjtanaa
---
 vllm/model_executor/layers/fused_moe/layer.py     |  2 +-
 .../layers/fused_moe/rocm_aiter_fused_moe.py      | 11 +++++++++--
 .../compressed_tensors/compressed_tensors_moe.py  |  3 ++-
 vllm/model_executor/layers/quantization/fp8.py    |  4 +++-
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index d6ead084af99c..65a46ba5554bd 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -646,13 +646,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             indices_type=self.topk_indices_dtype)
 
         if self.rocm_aiter_moe_enabled:
-            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 00f1b1f6b911f..93e20c3477bbe 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -315,7 +315,8 @@ def rocm_aiter_fused_experts(
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None,
+        expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     activation_method = (ActivationMethod.SILU
                          if activation == "silu" else ActivationMethod.GELU)
@@ -323,6 +324,11 @@ def rocm_aiter_fused_experts(
     topk_weights = topk_weights.to(torch.float32)
     topk_ids = topk_ids.to(torch.int32)
 
+    if expert_map is not None:
+        expert_mask = (expert_map > -1).to(torch.int32)
+    else:
+        expert_mask = None
+
     # w8a8 per-channel quantization
     if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
         # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=None,
+            expert_mask=expert_mask,
             activation_method=activation_method)
 
     else:
@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        expert_mask=expert_mask,
         quant_method=quant_method,
         activation_method=activation_method,
         w1_scale=w1_scale,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index fa4ce5668091b..92b82f5a02ffb 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -633,7 +633,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
                 a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale)
+                a2_scale=layer.w2_input_scale,
+                expert_map=expert_map)
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 60df679a74bda..ead345c794b8b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        from vllm.model_executor.layers.fused_moe import fused_experts
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -879,7 +880,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 if self.block_quant else layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size)
+            block_shape=self.quant_config.weight_block_size,
+            expert_map=expert_map)
         elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
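
Note (reviewer sketch, not part of the patch): in vLLM's expert parallelism,
`expert_map` maps each global expert id to the local expert index on the
current rank, with -1 marking experts hosted on other ranks; the patch derives
AITER's 0/1 `expert_mask` from it via `expert_map > -1`. Below is a minimal
illustration of that convention, assuming 8 global experts split evenly across
2 EP ranks (the layout and sizes are illustrative, not taken from the patch):

    import torch

    # Illustrative EP layout: 8 global experts, 2 ranks, 4 local experts each.
    global_num_experts, ep_size, ep_rank = 8, 2, 0
    num_local = global_num_experts // ep_size

    # expert_map[g] = local index of global expert g on this rank, or -1 if
    # the expert lives on another rank (the convention the patch relies on).
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    expert_map[ep_rank * num_local:(ep_rank + 1) * num_local] = torch.arange(
        num_local, dtype=torch.int32)

    # The patch converts this into the 0/1 per-expert validity mask that is
    # forwarded to the AITER kernels as `expert_mask`:
    expert_mask = (expert_map > -1).to(torch.int32)

    print(expert_map)   # tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32)
    print(expert_mask)  # tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)

Passing this mask (instead of asserting `expert_map is None`) is what lets the
AITER fused-MoE path skip experts that are not resident on the local rank.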