From b57b967386e1962955c93b7cb39828448b68789b Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:42:58 -0500 Subject: [PATCH] [MoE Refactor][7/N] AITER MK (#31102) Signed-off-by: Robert Shaw Co-authored-by: Robert Shaw --- .../layers/fused_moe/fused_moe.py | 6 +- .../layers/fused_moe/prepare_finalize.py | 9 ++ .../layers/fused_moe/rocm_aiter_fused_moe.py | 79 ++++++++++++ .../model_executor/layers/quantization/fp8.py | 116 ++++++++---------- 4 files changed, 144 insertions(+), 66 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c8d80ae023d43..bf51554341607 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): torch.float16, torch.bfloat16, torch.float8_e4m3fn, + torch.float8_e4m3fnuz, ] E, num_tokens, N, K, top_k_num = self.moe_problem_size( @@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 - elif hidden_states.dtype == torch.float8_e4m3fn: + elif ( + hidden_states.dtype == torch.float8_e4m3fn + or hidden_states.dtype == torch.float8_e4m3fnuz + ): compute_type = tl.bfloat16 else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index e27e2eb32da0f..5d806fa843a3c 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): + def __init__(self, defer_input_quant: bool = False) -> None: + super().__init__() + self.defer_input_quant = defer_input_quant + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): # Note: do not use inplace for shared experts overlap a1 = a1 * topk_weights.to(a1.dtype) + # Defer input quant to moe kernel for backends (e.g. AITER, FI) + # which use a single kernel call for quant + experts. + if self.defer_input_quant: + return a1, None, None, None, None + a1q, a1q_scale = moe_kernel_quantize_input( a1, quant_config.a1_scale, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 882ad0a537cd5..ebd9e3a4a8f2a 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -5,11 +5,15 @@ from functools import lru_cache import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) class QuantMethod(IntEnum): @@ -263,3 +267,78 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, ) + + +class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self, quant_config): + super().__init__(quant_config) + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_expert_map(self): + return True + + def supports_chunking(self): + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # Workspaces are managed internally by AITER. + workspace1 = (0,) + workspace2 = (0,) + output = (M, K) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert a1q_scale is None + assert a2_scale is None + assert expert_tokens_meta is None + + result = rocm_aiter_fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.quant_config, + ) + assert result.shape == output.shape + output.copy_(result) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d19b20798ed06..9da19c082dc27 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum): DEEPGEMM = 3 MARLIN = 4 TRITON = 5 + AITER = 6 def get_fp8_moe_backend( @@ -189,6 +190,10 @@ def get_fp8_moe_backend( logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local") return Fp8MoeBackend.DEEPGEMM + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: + logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local") + return Fp8MoeBackend.AITER + # default to Triton logger.info_once("Using Triton backend for FP8 MoE") return Fp8MoeBackend.TRITON @@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale = None layer.w2_input_scale = None - self.rocm_aiter_moe_enabled = False - def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # Lazy import to avoid importing triton too early. - - self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() - # TODO (rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" @@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv) replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv) - if self.rocm_aiter_moe_enabled: + if self.fp8_backend == Fp8MoeBackend.AITER: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data @@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) start += shard_size - if self.rocm_aiter_moe_enabled: + if self.fp8_backend == Fp8MoeBackend.AITER: shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.moe_quant_config = config self.kernel = mk.FusedMoEModularKernel( + # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP + # with the changes to defer input quantization FlashInferAllGatherMoEPrepareAndFinalize( use_dp=(self.moe.dp_size > 1), use_deepseek_fp8_block_scale=self.block_quant, @@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON, Fp8MoeBackend.MARLIN, + Fp8MoeBackend.AITER, ]: from vllm.model_executor.layers.fused_moe import ( TritonOrDeepGemmExperts, @@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase): from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterExperts, + ) config = self.get_fused_moe_quant_config(layer) assert config is not None self.moe_quant_config = config - use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN - allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM - moe_kernel = ( - MarlinExperts(quant_config=self.moe_quant_config) - if use_marlin - else TritonOrDeepGemmExperts( - quant_config=self.moe_quant_config, - allow_deep_gemm=allow_deep_gemm, - ) - ) - self.kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), moe_kernel - ) + if self.fp8_backend == Fp8MoeBackend.AITER: + self.kernel = mk.FusedMoEModularKernel( + # TODO: make defer_input_quant an attr of the AiterExperts + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + AiterExperts(quant_config=self.moe_quant_config), + ) + elif self.fp8_backend == Fp8MoeBackend.MARLIN: + self.kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + MarlinExperts(quant_config=self.moe_quant_config), + ) + else: + self.kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonOrDeepGemmExperts( + quant_config=self.moe_quant_config, + allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM), + ), + ) self.use_inplace = True def maybe_make_prepare_finalize( @@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: if ( - self.rocm_aiter_moe_enabled + self.fp8_backend == Fp8MoeBackend.AITER or self.fp8_backend == Fp8MoeBackend.MARLIN or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): @@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): TritonOrDeepGemmExperts, ) - assert ( - self.fp8_backend != Fp8MoeBackend.MARLIN - ) and not self.rocm_aiter_moe_enabled, ( - "Marlin and ROCm AITER are not supported with all2all yet." - ) + if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: + raise NotImplementedError( + "Marlin and ROCm AITER are not supported with all2all yet." + ) assert self.moe_quant_config is not None @@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): hidden_states=x, router_logits=router_logits, ) - - if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_fused_experts, - ) - - # TODO(rob): convert this to MK. - result = rocm_aiter_fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - ) - else: - result = self.kernel( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=self.use_inplace, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) + result = self.kernel( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=self.use_inplace, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) return result @@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): layer.w13_input_scale = None layer.w2_input_scale = None - self.rocm_aiter_moe_enabled = False - def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # Lazy import to avoid importing triton too early. - self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() - # If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) @@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): replace_parameter(layer, "w2_weight", w2_weight) # Reshuffle weights for AITER if needed. - if self.rocm_aiter_moe_enabled: + if self.fp8_backend == Fp8MoeBackend.AITER: shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): replace_parameter(layer, "w2_weight", shuffled_w2) # Rushuffle weights for MARLIN if needed. - if self.fp8_backend == Fp8MoeBackend.MARLIN: + elif self.fp8_backend == Fp8MoeBackend.MARLIN: prepare_moe_fp8_layer_for_marlin( layer, False, input_dtype=self.marlin_input_dtype )