[MoE Refactor][9/N] Use modular kernel for unquantized Triton MoE (#31052)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Yongye Zhu 2025-12-22 09:34:19 -08:00 committed by GitHub
parent ab3a85fd68
commit 7b926e8901
2 changed files with 22 additions and 7 deletions

@@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
+from vllm.v1.worker.workspace import init_workspace_manager

 NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
@@ -487,6 +488,7 @@ def test_mixtral_moe(
     monkeypatch.setenv("MASTER_ADDR", "localhost")
     monkeypatch.setenv("MASTER_PORT", "12345")
     init_distributed_environment()
+    init_workspace_manager(torch.cuda.current_device())

     # Instantiate our and huggingface's MoE blocks
     vllm_config.compilation_config.static_forward_context = dict()
@@ -533,6 +535,11 @@ def test_mixtral_moe(
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
+
+    # FIXME (zyongye) fix this after we move self.kernel
+    # assignment in FusedMoE.__init__
+    vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
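
The ordering above is the point of the test change: as the second file's hunks below show, the modular kernel is now constructed in process_weights_after_loading, so the test has to initialize the workspace manager and run weight post-processing before the first forward pass. A minimal sketch of that ordering, reusing the names from this test (not standalone code):

init_distributed_environment()
init_workspace_manager(torch.cuda.current_device())  # workspace manager set up before any MoE kernel runs

# The modular Triton kernel is built inside process_weights_after_loading,
# so it has to run before the first forward pass on the vLLM MoE block.
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)

vllm_states = vllm_moe.forward(vllm_inputs)  # forward is only valid after the steps above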

@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F

 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
@@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
-    from .fused_moe import TritonExperts, fused_experts
+    from .fused_moe import TritonExperts
 else:
-    fused_experts = None  # type: ignore
+    TritonExperts = None  # type: ignore

 if current_platform.is_tpu():
     from .moe_pallas import fused_moe as fused_moe_pallas
@@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
             else:
                 layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
+        elif current_platform.is_cuda_alike():
+            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(),
+                TritonExperts(self.moe_quant_config),
+                shared_experts=None,
+            )

     def apply(
         self,
@@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             router_logits=router_logits,
         )

-    def get_fused_moe_quant_config(
-        self, layer: torch.nn.Module
-    ) -> FusedMoEQuantConfig | None:
+    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
         if self.moe.has_bias:
             return biased_moe_quant_config(
                 layer.w13_bias,
@@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
-            result = fused_experts(
+            result = self.kernel(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_ids=topk_ids,
                 inplace=True,
                 activation=layer.activation,
-                quant_config=self.moe_quant_config,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 global_num_experts=layer.global_num_experts,
                 expert_map=layer.expert_map,
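
Taken together, this file now builds a FusedMoEModularKernel once per layer in process_weights_after_loading instead of calling the module-level fused_experts(...) helper on every forward; the quant config moves from a per-call keyword argument into the TritonExperts instance. A condensed sketch of the new path, using only the names visible in the hunks above (argument list abbreviated, not standalone code):

# In process_weights_after_loading, on CUDA-alike platforms:
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
self.kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),           # prepare/finalize without expert parallelism (per the class name)
    TritonExperts(self.moe_quant_config),  # Triton expert compute; quant config captured here
    shared_experts=None,
)

# At the call site, self.kernel takes the place of fused_experts, and
# quant_config is no longer passed per call.
result = self.kernel(
    hidden_states=x,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    topk_ids=topk_ids,
    inplace=True,
    activation=layer.activation,
    apply_router_weight_on_input=layer.apply_router_weight_on_input,
    global_num_experts=layer.global_num_experts,
    expert_map=layer.expert_map,
    # ... remaining keyword arguments unchanged from the old fused_experts call
)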