commit 7b926e8901
parent ab3a85fd68

[MoE Refactor][9/N] Use modular kernel for unquantized Triton MoE (#31052)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
+from vllm.v1.worker.workspace import init_workspace_manager
 
 NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
@@ -487,6 +488,7 @@ def test_mixtral_moe(
     monkeypatch.setenv("MASTER_ADDR", "localhost")
     monkeypatch.setenv("MASTER_PORT", "12345")
     init_distributed_environment()
+    init_workspace_manager(torch.cuda.current_device())
 
     # Instantiate our and huggingface's MoE blocks
     vllm_config.compilation_config.static_forward_context = dict()
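Note: the new init_workspace_manager call has to run after the distributed setup and before the layer is exercised, presumably because the modular kernel draws its scratch buffers from a per-device workspace. A minimal sketch of that ordering, assuming a single CUDA device; the vllm.distributed import path is an assumption (the diff only shows the call), and the real test sets the env vars via monkeypatch:

import os

import torch
from vllm.distributed import init_distributed_environment
from vllm.v1.worker.workspace import init_workspace_manager

# Single-process rendezvous, mirroring the monkeypatched env vars above.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12345")

init_distributed_environment()
# Register a workspace for this device so the modular kernel can allocate
# its intermediate buffers during forward.
init_workspace_manager(torch.cuda.current_device())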
@@ -533,6 +535,11 @@ def test_mixtral_moe(
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
 
+    # FIXME (zyongye) fix this after we move self.kernel
+    # assignment in FusedMoE.__init__
+
+    vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
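Note: the FIXME exists because self.kernel is assigned inside process_weights_after_loading rather than in FusedMoE.__init__, so a freshly constructed layer cannot run forward until that hook fires. A hypothetical guard illustrating the ordering constraint; the class below is a stand-in, not vLLM's UnquantizedFusedMoEMethod:

class MethodSketch:
    """Stand-in lifecycle: the kernel only exists after weight post-processing."""

    def __init__(self) -> None:
        self.kernel = None  # set later, by process_weights_after_loading

    def process_weights_after_loading(self) -> None:
        self.kernel = lambda x: x  # stand-in for mk.FusedMoEModularKernel(...)

    def apply(self, x):
        assert self.kernel is not None, (
            "call process_weights_after_loading() before forward()"
        )
        return self.kernel(x)


m = MethodSketch()
m.process_weights_after_loading()  # required first, exactly as in the test above
print(m.apply([1.0, 2.0]))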
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 
 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
@@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
-    from .fused_moe import TritonExperts, fused_experts
+    from .fused_moe import TritonExperts
 else:
-    fused_experts = None  # type: ignore
     TritonExperts = None  # type: ignore
 
 if current_platform.is_tpu():
     from .moe_pallas import fused_moe as fused_moe_pallas
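Note: the platform guard keeps GPU-only symbols importable everywhere by binding None fallbacks, which is why the now-unused fused_experts binding could simply be dropped from both branches. The same pattern in miniature (torch.cuda.is_available stands in for current_platform.is_cuda_alike; the imported symbol is illustrative):

import torch

if torch.cuda.is_available():  # stand-in for current_platform.is_cuda_alike()
    from torch.cuda import synchronize as gpu_sync  # stand-in GPU-only symbol
else:
    gpu_sync = None  # type: ignore

def maybe_sync() -> None:
    # Callers feature-test the binding instead of re-checking the platform.
    if gpu_sync is not None:
        gpu_sync()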
@@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
             else:
                 layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
+        elif current_platform.is_cuda_alike():
+            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(),
+                TritonExperts(self.moe_quant_config),
+                shared_experts=None,
+            )
 
     def apply(
         self,
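Note: this hunk is the heart of the change. Instead of routing through the free function fused_experts, the unquantized CUDA path now builds a FusedMoEModularKernel that composes a prepare/finalize stage (MoEPrepareAndFinalizeNoEP, since there is no expert parallelism to dispatch across) with an experts compute stage (TritonExperts, which captures the quant config at construction). A simplified sketch of that composition; the classes below are illustrative stand-ins, not vLLM's actual interfaces:

import torch


class PrepareFinalizeSketch:
    """Stand-in for MoEPrepareAndFinalizeNoEP: nothing to dispatch or combine."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # with EP this would scatter tokens to ranks

    def finalize(self, out: torch.Tensor) -> torch.Tensor:
        return out  # with EP this would gather expert outputs back


class ExpertsSketch:
    """Stand-in for TritonExperts: the config is bound once, at construction."""

    def __init__(self, quant_config=None) -> None:
        self.quant_config = quant_config

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # stand-in for the fused expert GEMMs


class ModularKernelSketch:
    """Stand-in for mk.FusedMoEModularKernel: prepare -> experts -> finalize."""

    def __init__(self, pf, experts, shared_experts=None) -> None:
        self.pf, self.experts, self.shared_experts = pf, experts, shared_experts

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.pf.finalize(self.experts.apply(self.pf.prepare(hidden_states)))


kernel = ModularKernelSketch(PrepareFinalizeSketch(), ExpertsSketch(), None)
print(kernel(torch.randn(4, 8)).shape)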
@@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             router_logits=router_logits,
         )
 
-    def get_fused_moe_quant_config(
-        self, layer: torch.nn.Module
-    ) -> FusedMoEQuantConfig | None:
+    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
         if self.moe.has_bias:
             return biased_moe_quant_config(
                 layer.w13_bias,
@@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
-            result = fused_experts(
+            result = self.kernel(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_ids=topk_ids,
                 inplace=True,
                 activation=layer.activation,
-                quant_config=self.moe_quant_config,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 global_num_experts=layer.global_num_experts,
                 expert_map=layer.expert_map,
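Note: the per-call quant_config= keyword disappears because the config now travels with the experts object (TritonExperts(self.moe_quant_config) above) instead of with every dispatch; this also explains why get_fused_moe_quant_config dropped the `| None` from its return type. An illustrative contrast with hypothetical names:

class KernelSketch:
    """Hypothetical kernel: the config is bound at construction, not per call."""

    def __init__(self, quant_config) -> None:
        self.quant_config = quant_config

    def __call__(self, hidden_states, inplace=True):
        # self.quant_config is consulted internally; callers no longer pass it.
        return hidden_states


# Built once, inside process_weights_after_loading.
kernel = KernelSketch(quant_config=None)
# Per-forward dispatch: no quant_config= at the call site anymore.
out = kernel(hidden_states=[1.0, 2.0])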