mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 04:37:03 +08:00
[MoE Refactor][9/N] Use modular kernel for unquantized Triton MoE (#31052)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
ab3a85fd68
commit
7b926e8901
@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
|
|||||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
from vllm.v1.worker.workspace import init_workspace_manager
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64, 192]
|
NUM_EXPERTS = [8, 64, 192]
|
||||||
EP_SIZE = [1, 4]
|
EP_SIZE = [1, 4]
|
||||||
@ -487,6 +488,7 @@ def test_mixtral_moe(
|
|||||||
monkeypatch.setenv("MASTER_ADDR", "localhost")
|
monkeypatch.setenv("MASTER_ADDR", "localhost")
|
||||||
monkeypatch.setenv("MASTER_PORT", "12345")
|
monkeypatch.setenv("MASTER_PORT", "12345")
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
|
init_workspace_manager(torch.cuda.current_device())
|
||||||
|
|
||||||
# Instantiate our and huggingface's MoE blocks
|
# Instantiate our and huggingface's MoE blocks
|
||||||
vllm_config.compilation_config.static_forward_context = dict()
|
vllm_config.compilation_config.static_forward_context = dict()
|
||||||
@ -533,6 +535,11 @@ def test_mixtral_moe(
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# FIXME (zyongye) fix this after we move self.kernel
|
||||||
|
# assignment in FusedMoE.__init__
|
||||||
|
|
||||||
|
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
|
||||||
|
|
||||||
# Run forward passes for both MoE blocks
|
# Run forward passes for both MoE blocks
|
||||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|||||||
FusedMoEPermuteExpertsUnpermute,
|
FusedMoEPermuteExpertsUnpermute,
|
||||||
FusedMoEPrepareAndFinalize,
|
FusedMoEPrepareAndFinalize,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
|
MoEPrepareAndFinalizeNoEP,
|
||||||
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
|||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .fused_batched_moe import BatchedTritonExperts
|
from .fused_batched_moe import BatchedTritonExperts
|
||||||
from .fused_moe import TritonExperts, fused_experts
|
from .fused_moe import TritonExperts
|
||||||
else:
|
else:
|
||||||
fused_experts = None # type: ignore
|
TritonExperts = None # type: ignore
|
||||||
|
|
||||||
if current_platform.is_tpu():
|
if current_platform.is_tpu():
|
||||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||||
@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||||
else:
|
else:
|
||||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||||
|
elif current_platform.is_cuda_alike():
|
||||||
|
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||||
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
|
TritonExperts(self.moe_quant_config),
|
||||||
|
shared_experts=None,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||||
self, layer: torch.nn.Module
|
|
||||||
) -> FusedMoEQuantConfig | None:
|
|
||||||
if self.moe.has_bias:
|
if self.moe.has_bias:
|
||||||
return biased_moe_quant_config(
|
return biased_moe_quant_config(
|
||||||
layer.w13_bias,
|
layer.w13_bias,
|
||||||
@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = fused_experts(
|
result = self.kernel(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=layer.activation,
|
activation=layer.activation,
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||||
global_num_experts=layer.global_num_experts,
|
global_num_experts=layer.global_num_experts,
|
||||||
expert_map=layer.expert_map,
|
expert_map=layer.expert_map,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user