feat: BF16 FlashInfer Fused Cutlass MOE for Hopper and Blackwell Expert Parallel (#25503)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Duncan Moss 2025-09-24 15:50:04 -07:00 committed by yewentao256
parent b4a80dad98
commit 461aa1463b
5 changed files with 121 additions and 6 deletions

View File

@ -144,6 +144,7 @@ if TYPE_CHECKING:
VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
@ -1145,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_MOE_GROUPED_TOPK":
lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP16":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
@ -1516,6 +1521,7 @@ def compute_hash() -> str:
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",

View File

@ -52,8 +52,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
tp_size: int = 1,
):
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.")
assert quant_config.quant_dtype in (
"nvfp4", torch.float8_e4m3fn,
None), ("Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
@ -109,8 +111,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
aq_m, aq_n = aq.shape
workspace2 = (0, )
output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
torch.float8_e4m3fn else (aq_m, aq_n)
output_shape = (aq_m,
aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m,
aq_n)
workspace_dtype = a.dtype
workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any
@ -135,6 +138,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
):
assert activation == "silu", ("Only activation silu is supported in "
"FlashInferExperts")
if self.quant_dtype == torch.float8_e4m3fn:
quant_scales = [
self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
@ -143,7 +150,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q_scale = None # not passing input_sf in fp8
fc1_expert_weights = w1
fc2_expert_weights = w2
else:
elif self.quant_dtype == "nvfp4":
# Ensure w1_scale and w2_scale are not None before calling view
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not "
@ -161,6 +168,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
else:
quant_scales = None
a1q_scale = None
fc1_expert_weights = w1
fc2_expert_weights = w2
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def flashinfer_cutlass_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
tp_rank: int = 0,
tp_size: int = 1,
ep_rank: int = 0,
ep_size: int = 1,
use_dp: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=use_dp),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
tp_rank=tp_rank,
tp_size=tp_size,
ep_rank=ep_rank,
ep_size=ep_size,
))
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -183,7 +183,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(
dim=0,
sizes=get_local_sizes(),
)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
if quant_config.quant_dtype == "nvfp4":
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@ -39,6 +39,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
@ -296,6 +297,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
self.flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
and current_platform.get_device_capability()[0] >= 9)
if self.flashinfer_cutlass_moe_enabled:
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size)
else:
if (self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1):
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.")
elif self.moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP."
)
self.flashinfer_cutlass_moe = None # type: ignore
def maybe_make_prepare_finalize(
self) -> Optional[FusedMoEPrepareAndFinalize]:
if self.rocm_aiter_moe_enabled:
@ -367,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@ -386,6 +422,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
if self.flashinfer_cutlass_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
@ -536,6 +578,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
if self.moe.has_bias:
raise ValueError(

View File

@ -598,6 +598,8 @@ class SharedResizableBuffer:
def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype):
if shape == () or shape is None:
return None
shape_numel = prod(shape)
if (self.buffer is None or self.buffer.numel() < shape_numel
or self.buffer.device != device or self.buffer.dtype != dtype):