[V1] [ROCm] Enable EP with AITER Fused MoE (#20270)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author:    TJian
Date:      2025-07-01 09:48:30 -07:00
Committer: GitHub
Parent:    3d19d47d91
Commit:    02cabff207
4 changed files with 15 additions and 5 deletions

@@ -646,13 +646,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             indices_type=self.topk_indices_dtype)

         if self.rocm_aiter_moe_enabled:
-            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
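The removed `assert expert_map is None` was the guard that previously blocked expert parallelism (EP) on the AITER path; the map is now forwarded instead. For context, here is a minimal sketch of what `expert_map` holds under vLLM's EP sharding. `build_expert_map` below is an illustrative helper, not vLLM's actual code: entry `e` is the rank-local index of global expert `e`, or -1 when that expert lives on another rank.

import torch

def build_expert_map(global_num_experts: int, ep_size: int,
                     ep_rank: int) -> torch.Tensor:
    # Contiguous sharding: each EP rank hosts an equal slice of the experts.
    num_local = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    start = ep_rank * num_local
    expert_map[start:start + num_local] = torch.arange(num_local,
                                                       dtype=torch.int32)
    return expert_map

# 8 experts over 2 EP ranks; rank 1 hosts global experts 4..7 as local 0..3:
print(build_expert_map(8, 2, 1))
# -> tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)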

@@ -315,7 +315,8 @@ def rocm_aiter_fused_experts(
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
-       block_shape: Optional[list[int]] = None) -> torch.Tensor:
+       block_shape: Optional[list[int]] = None,
+       expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:

    activation_method = (ActivationMethod.SILU
                         if activation == "silu" else ActivationMethod.GELU)
@@ -323,6 +324,11 @@ def rocm_aiter_fused_experts(
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

+    if expert_map is not None:
+        expert_mask = (expert_map > -1).to(torch.int32)
+    else:
+        expert_mask = None
+
    # w8a8 per-channel quantization
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
            fc2_smooth_scale=None,
            a16=False,
            per_tensor_quant_scale=None,
-            expert_mask=None,
+            expert_mask=expert_mask,
            activation_method=activation_method)

    else:
@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
        w2,
        topk_weights,
        topk_ids,
+        expert_mask=expert_mask,
        quant_method=quant_method,
        activation_method=activation_method,
        w1_scale=w1_scale,
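The hunks above are the core of the change: `rocm_aiter_fused_experts` grows an `expert_map` parameter and collapses it into the 0/1 `expert_mask` that the AITER entry points already accepted (previously hard-coded to `expert_mask=None`). A worked example of that one-line conversion:

import torch

# Say this rank hosts global experts 2 and 3 (local indices 0 and 1).
expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)
expert_mask = (expert_map > -1).to(torch.int32)
print(expert_mask)
# -> tensor([0, 0, 1, 1], dtype=torch.int32)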

@@ -633,7 +633,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale)
+                a2_scale=layer.w2_input_scale,
+                expert_map=expert_map)
        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")

@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    """

    def __init__(self, quant_config: Fp8Config):
+        from vllm.model_executor.layers.fused_moe import fused_experts
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
@@ -879,7 +880,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                    if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size)
+                block_shape=self.quant_config.weight_block_size,
+                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")