Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-27 07:15:15 +08:00)
[MoE Refactor][2/N] Use Modular Kernels for Fp8 (#30825)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 4cf9429897
commit 95befecc18
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from enum import Enum
-from functools import partial
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -51,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
     apply_flashinfer_per_tensor_scale_fp8,
     build_flashinfer_fp8_cutlass_moe_prepare_finalize,
-    flashinfer_cutlass_moe_fp8,
     get_flashinfer_moe_backend,
     register_moe_scaling_factors,
     rotate_flashinfer_fp8_moe_weights,
@@ -728,18 +726,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
         elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
             self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
-            if self.block_quant:
-                assert self.weight_block_size == [128, 128], (
-                    f"Only support weight_block_size == [128, 128], "
-                    f"got {self.weight_block_size}"
-                )
-            self.flashinfer_moe_fn = partial(
-                flashinfer_cutlass_moe_fp8,
-                moe=self.moe,
-                use_deepseek_fp8_block_scale=self.block_quant,
-            )
+            if self.block_quant and self.weight_block_size != [128, 128]:
+                raise NotImplementedError(
+                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
+                    "size [128, 128]."
+                )
+            if not self.block_quant:
+                if layer.renormalize or layer.custom_routing_function is not None:
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend does not support "
+                        "custom routing functions or renormalization, but got "
+                        f"{layer.renormalize} and {layer.custom_routing_function}."
+                    )
+                if layer.scoring_func != "sigmoid":
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend only supports "
+                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
+                    )
+                if layer.activation != "silu":
+                    raise NotImplementedError(
+                        "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
+                        f"activation function, but got {layer.activation}."
+                    )
 
         self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
 
     def create_weights(
         self,
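The validation change above is not only stylistic: `assert` statements are stripped when Python runs with `-O`, so the old checks could silently accept an unsupported config, while an explicit `NotImplementedError` survives optimization and names the limitation. A minimal sketch of the pattern, using a hypothetical check_block_size helper rather than vLLM's actual code:

def check_block_size(block_quant: bool, weight_block_size: list[int]) -> None:
    # Unlike `assert`, this raise is not removed under `python -O`.
    if block_quant and weight_block_size != [128, 128]:
        raise NotImplementedError(
            "FlashInfer CUTLASS FP8 MoE backend only supports block "
            f"size [128, 128], got {weight_block_size}."
        )

check_block_size(block_quant=True, weight_block_size=[128, 128])  # passes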
@@ -928,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm:
+        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
             dg_w13_weight, dg_w13_weight_scale_inv = (
                 deepgemm_post_process_fp8_weight_block(
                     wq=layer.w13_weight.data,
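The comment in this hunk carries the reasoning: DeepGEMM wants block scales in a transposed, aligned layout, and doing that once at weight-load time keeps the reshuffle off the forward path. A rough sketch of the idea, assuming a simple transpose-and-make-contiguous target layout (the real deepgemm_post_process_fp8_weight_block may do more, such as padding for alignment):

import torch

def precompute_scale_layout(w_scale_inv: torch.Tensor) -> torch.Tensor:
    # Pay for the layout change once at load time, not on every forward.
    return w_scale_inv.t().contiguous()

scales = torch.rand(14, 28)  # [n_blocks_k, n_blocks_n], illustrative shape
prepped = precompute_scale_layout(scales)
assert prepped.is_contiguous() and prepped.shape == (28, 14)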
@@ -1039,6 +1047,61 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
+        # NOTE(rob): this is a WIP refactor. We are first migrating
+        # all of the kernels in the TP case to use mk. Once this is
+        # done, then we will initialize the TP case and DP/EP case
+        # via the same code path (i.e. via maybe_init_modular_kernel).
+        # NOTE(rob): in progress migrating all into this format.
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+                FlashInferExperts,
+            )
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+                FlashInferAllGatherMoEPrepareAndFinalize,
+            )
+
+            config = self.get_fused_moe_quant_config(layer)
+            assert config is not None
+            self.moe_quant_config = config
+
+            self.kernel = mk.FusedMoEModularKernel(
+                FlashInferAllGatherMoEPrepareAndFinalize(
+                    use_dp=(self.moe.dp_size > 1),
+                    use_deepseek_fp8_block_scale=self.block_quant,
+                ),
+                FlashInferExperts(
+                    out_dtype=torch.get_default_dtype(),
+                    quant_config=self.moe_quant_config,
+                    ep_rank=self.moe.ep_rank,
+                    ep_size=self.moe.ep_size,
+                    tp_rank=self.moe.tp_rank,
+                    tp_size=self.moe.tp_size,
+                    use_dp=(self.moe.dp_size > 1),
+                    use_deepseek_fp8_block_scale=self.block_quant,
+                ),
+            )
+            self.use_inplace = False
+
+        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
+            from vllm.model_executor.layers.fused_moe import (
+                TritonOrDeepGemmExperts,
+            )
+            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+                MoEPrepareAndFinalizeNoEP,
+            )
+
+            config = self.get_fused_moe_quant_config(layer)
+            assert config is not None
+            self.moe_quant_config = config
+            self.kernel = mk.FusedMoEModularKernel(
+                MoEPrepareAndFinalizeNoEP(),
+                TritonOrDeepGemmExperts(
+                    quant_config=self.moe_quant_config,
+                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+                ),
+            )
+            self.use_inplace = True
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
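The NOTE above states the direction of the refactor: every FP8 path in the TP case is expressed as one mk.FusedMoEModularKernel, composing a prepare/finalize stage (communication and layout) with an experts stage (the actual GEMMs). A schematic, runnable sketch of that composition; the classes below are simplified stand-ins, not vLLM's mk interfaces:

from typing import Protocol

import torch

class PrepareFinalize(Protocol):
    def prepare(self, x: torch.Tensor) -> torch.Tensor: ...
    def finalize(self, y: torch.Tensor) -> torch.Tensor: ...

class Experts(Protocol):
    def apply(self, x: torch.Tensor) -> torch.Tensor: ...

class ModularKernel:
    """Compose a communication stage with an expert-compute stage."""

    def __init__(self, pf: PrepareFinalize, experts: Experts) -> None:
        self.pf = pf
        self.experts = experts

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.pf.finalize(self.experts.apply(self.pf.prepare(x)))

class NoEPPrepareFinalize:
    def prepare(self, x: torch.Tensor) -> torch.Tensor:
        return x  # pure-TP case: no cross-rank gather needed

    def finalize(self, y: torch.Tensor) -> torch.Tensor:
        return y

class DoublingExperts:
    def apply(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stand-in for the fused expert GEMMs

kernel = ModularKernel(NoEPPrepareFinalize(), DoublingExperts())
print(kernel(torch.ones(2, 3)))

Under this split, swapping FlashInferAllGatherMoEPrepareAndFinalize for MoEPrepareAndFinalizeNoEP changes only the communication half while the experts half stays put, which is the point of the modular-kernel design.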
@@ -1091,7 +1154,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         assert max_num_tokens_per_rank is not None
 
         experts_impl = (
-            BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
+            BatchedDeepGemmExperts
+            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
+            else BatchedTritonExperts
         )
         logger.debug(
             "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
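A small pattern worth noting in the hunk above: the backend enum selects a class object, and only the caller instantiates it with backend-specific arguments. A self-contained sketch with stand-in names (not vLLM's actual classes):

from enum import Enum

class Fp8MoeBackendSketch(Enum):
    DEEPGEMM = "deepgemm"
    TRITON = "triton"

class BatchedDeepGemmExpertsSketch: ...
class BatchedTritonExpertsSketch: ...

def pick_experts(backend: Fp8MoeBackendSketch) -> type:
    # Classes are first-class objects in Python, so the conditional
    # yields a type that the caller later instantiates with kwargs.
    return (
        BatchedDeepGemmExpertsSketch
        if backend == Fp8MoeBackendSketch.DEEPGEMM
        else BatchedTritonExpertsSketch
    )

print(pick_experts(Fp8MoeBackendSketch.TRITON).__name__)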
@@ -1126,7 +1191,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         return TritonOrDeepGemmExperts(
             quant_config=self.moe_quant_config,
-            allow_deep_gemm=self.allow_deep_gemm,
+            allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
         )
 
     def get_fused_moe_quant_config(
@@ -1164,6 +1229,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            # TODO(rob): convert this to MK.
             if layer.enable_eplb:
                 raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
             assert layer.activation == "silu", (
@@ -1228,6 +1294,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 rocm_aiter_fused_experts,
             )
 
+            # TODO(rob): convert this to MK.
             result = rocm_aiter_fused_experts(
                 x,
                 layer.w13_weight,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
# TODO(rob): convert this to MK.
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
@@ -1261,47 +1329,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 input_dtype=self.marlin_input_dtype,
                 workspace=layer.workspace,
             )
-        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert layer.activation == "silu", (
-                f"Expected 'silu' activation but got {layer.activation}"
-            )
-            if not self.block_quant:
-                assert (
-                    not layer.renormalize and layer.custom_routing_function is not None
-                )
-                assert layer.scoring_func == "sigmoid", (
-                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
-                )
-            # Delegate to CUTLASS FlashInfer path; function already bound with
-            # use_deepseek_fp8_block_scale for block-quant when applicable
-            result = self.flashinfer_moe_fn(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                inplace=False,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
-        else:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-
-            result = fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=True,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                quant_config=self.moe_quant_config,
-                allow_deep_gemm=self.allow_deep_gemm,
-            )
+        else:
+            result = self.kernel(
+                x,
+                layer,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                inplace=self.use_inplace,
+                activation=layer.activation,
+                global_num_experts=layer.global_num_experts,
+                expert_map=layer.expert_map,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            )
 
         if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
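After this hunk, both the fused_experts and FlashInfer CUTLASS branches collapse into one call on self.kernel, and the only per-backend difference left in apply() is the inplace policy set at weight-load time (False for FlashInfer CUTLASS, True for Triton/DeepGEMM). A minimal runnable sketch of what that flag trades off, under illustrative names rather than vLLM's API:

import torch

def fused_op(x: torch.Tensor, inplace: bool) -> torch.Tensor:
    # Kernels that may safely overwrite the activation buffer skip the
    # extra allocation; others write into a fresh output tensor.
    out = x if inplace else torch.empty_like(x)
    torch.mul(x, 2.0, out=out)  # stand-in for the fused MoE computation
    return out

x = torch.randn(4, 8)
y = fused_op(x, inplace=False)  # fresh buffer, x untouched
z = fused_op(x, inplace=True)   # reuses x's storage
assert z.data_ptr() == x.data_ptr()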