From 7b926e890195f026b60f36506503a80afc583b33 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 22 Dec 2025 09:34:19 -0800 Subject: [PATCH] [MoE Refactor][9/N] Use modular kernel for unquantized Triton MoE (#31052) Signed-off-by: Yongye Zhu --- tests/kernels/moe/test_moe.py | 7 ++++++ .../fused_moe/unquantized_fused_moe_method.py | 22 +++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ce99d9691fdc8..fd6ce6bfbd782 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +from vllm.v1.worker.workspace import init_workspace_manager NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] @@ -487,6 +488,7 @@ def test_mixtral_moe( monkeypatch.setenv("MASTER_ADDR", "localhost") monkeypatch.setenv("MASTER_PORT", "12345") init_distributed_environment() + init_workspace_manager(torch.cuda.current_device()) # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() @@ -533,6 +535,11 @@ def test_mixtral_moe( torch.cuda.synchronize() torch.cuda.empty_cache() + # FIXME (zyongye) fix this after we move self.kernel + # assignment in FusedMoE.__init__ + + vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts) + # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 1ee7b65b22e3f..82dbccf3fa9da 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts, fused_experts + from .fused_moe import TritonExperts else: - fused_experts = None # type: ignore + TritonExperts = None # type: ignore if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas @@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) + elif current_platform.is_cuda_alike(): + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + self.kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonExperts(self.moe_quant_config), + shared_experts=None, + ) def apply( self, @@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): router_logits=router_logits, ) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: if self.moe.has_bias: return biased_moe_quant_config( layer.w13_bias, @@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: - result = fused_experts( + result = self.kernel( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_ids=topk_ids, inplace=True, activation=layer.activation, - quant_config=self.moe_quant_config, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map,