[MoE Refactor][2/N] Use Modular Kernels for Fp8 (#30825)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-12-19 18:36:38 -05:00 committed by GitHub
parent 4cf9429897
commit 95befecc18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Optional
import torch
@ -51,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
@ -728,18 +726,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
if self.block_quant and self.weight_block_size != [128, 128]:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
if not self.block_quant:
if layer.renormalize or layer.custom_routing_function is not None:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend does custom routing "
f"function or renormalization, but got {layer.renormalize} and "
f"{layer.custom_routing_function}."
)
if layer.scoring_func != "sigmoid":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports "
f"'sigmoid' scoring function, but got {layer.scoring_func}."
)
if layer.activation != "silu":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
"activation function, but got {layer.activation}."
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
def create_weights(
self,
@ -928,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
dg_w13_weight, dg_w13_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w13_weight.data,
@ -1039,6 +1047,61 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, then we will initialzie the TP case and DP/EP case
# via the same code path (i.e. via maybe_init_modular_kernel).
# NOTE(rob): in progress migrating all into this format.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferAllGatherMoEPrepareAndFinalize,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel(
FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
),
FlashInferExperts(
out_dtype=torch.get_default_dtype(),
quant_config=self.moe_quant_config,
ep_rank=self.moe.ep_rank,
ep_size=self.moe.ep_size,
tp_rank=self.moe.tp_rank,
tp_size=self.moe.tp_size,
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
),
)
self.use_inplace = False
elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@ -1091,7 +1154,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert max_num_tokens_per_rank is not None
experts_impl = (
BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
BatchedDeepGemmExperts
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
else BatchedTritonExperts
)
logger.debug(
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
@ -1126,7 +1191,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
return TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
)
def get_fused_moe_quant_config(
@ -1164,6 +1229,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
@ -1228,6 +1294,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rocm_aiter_fused_experts,
)
# TODO(rob): convert this to MK.
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,
@ -1240,6 +1307,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config=self.moe_quant_config,
)
elif self.use_marlin:
# TODO(rob): convert this to MK.
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
@ -1261,47 +1329,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
if not self.block_quant:
assert (
not layer.renormalize and layer.custom_routing_function is not None
)
assert layer.scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
)
# Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable
result = self.flashinfer_moe_fn(
else:
result = self.kernel(
x,
layer,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (