From 461aa1463b8350a59415bc7eb32e6b601c1c695e Mon Sep 17 00:00:00 2001
From: Duncan Moss
Date: Wed, 24 Sep 2025 15:50:04 -0700
Subject: [PATCH] feat: BF16 FlashInfer Fused Cutlass MOE for Hopper and
 Blackwell Expert Parallel (#25503)

Signed-off-by: Duncan Moss
Signed-off-by: yewentao256
---
 vllm/envs.py                                  |  6 ++
 .../fused_moe/flashinfer_cutlass_moe.py       | 65 +++++++++++++++++--
 .../flashinfer_cutlass_prepare_finalize.py    |  3 +-
 vllm/model_executor/layers/fused_moe/layer.py | 51 +++++++++++++++
 .../layers/fused_moe/modular_kernel.py        |  2 +
 5 files changed, 121 insertions(+), 6 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 4797d96bb899a..5d622c0675290 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -144,6 +144,7 @@ if TYPE_CHECKING:
     VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
+    VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
@@ -1145,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK":
     lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
 
+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
+
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
@@ -1516,6 +1521,7 @@ def compute_hash() -> str:
         "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
         "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
+        "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
         "VLLM_USE_FLASHINFER_MOE_FP4",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 8700181d18feb..3ea4ed39e9568 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -52,8 +52,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         tp_size: int = 1,
     ):
         super().__init__(quant_config)
-        assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
-            "Only nvfp4,fp8 quantization are currently supported.")
+        assert quant_config.quant_dtype in (
+            "nvfp4", torch.float8_e4m3fn,
+            None), ("Only nvfp4, fp8, bfloat16 and"
+                    " float16 quantization are currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
@@ -109,8 +111,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         aq_m, aq_n = aq.shape
         workspace2 = (0, )
-        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
-            torch.float8_e4m3fn else (aq_m, aq_n)
+        output_shape = (aq_m,
+                        aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m,
+                                                                       aq_n)
         workspace_dtype = a.dtype
         workspace1 = output_shape
         # The workspace is determined by `aq`, since it comes after any
@@ -135,6 +138,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
     ):
+
+        assert activation == "silu", ("Only activation silu is supported in "
+                                      "FlashInferExperts")
+
         if self.quant_dtype == torch.float8_e4m3fn:
             quant_scales = [
                 self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
@@ -143,7 +150,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             a1q_scale = None  # not passing input_sf in fp8
             fc1_expert_weights = w1
             fc2_expert_weights = w2
-        else:
+        elif self.quant_dtype == "nvfp4":
             # Ensure w1_scale and w2_scale are not None before calling view
             assert self.w1_scale is not None and self.w2_scale is not None, (
                 "w1_scale and w2_scale must not "
@@ -161,6 +168,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights = w1.view(torch.long)
             fc2_expert_weights = w2.view(torch.long)
+        else:
+            quant_scales = None
+            a1q_scale = None
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2
 
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
@@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
+
+
+def flashinfer_cutlass_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    quant_config: FusedMoEQuantConfig,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+    tp_rank: int = 0,
+    tp_size: int = 1,
+    ep_rank: int = 0,
+    ep_size: int = 1,
+    use_dp: bool = False,
+) -> torch.Tensor:
+    fused_experts = mk.FusedMoEModularKernel(
+        create_flashinfer_prepare_finalize(use_dp=use_dp),
+        FlashInferExperts(
+            out_dtype=hidden_states.dtype,
+            quant_config=quant_config,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
+        ))
+
+    return fused_experts(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 6e127064d32d6..ed364ac77b286 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -183,7 +183,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(
                 dim=0,
                 sizes=get_local_sizes(),
             )
-            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+            if quant_config.quant_dtype == "nvfp4":
+                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
         return a1q, a1q_scale, None, topk_ids, topk_weights
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b68190e5d1c18..ea88539db27b5 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -39,6 +39,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep,
                         has_pplx, round_up)
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 if current_platform.is_cuda_alike():
@@ -296,6 +297,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore
 
+        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
+        self.flashinfer_cutlass_moe_enabled = (
+            has_flashinfer_cutlass_fused_moe()
+            and envs.VLLM_USE_FLASHINFER_MOE_FP16
+            and self.moe.moe_parallel_config.use_ep
+            and self.moe.moe_parallel_config.dp_size == 1
+            and current_platform.get_device_capability()[0] >= 9)
+        if self.flashinfer_cutlass_moe_enabled:
+            logger.info_once(
+                "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
+            )
+            from functools import partial
+
+            from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
+            self.flashinfer_cutlass_moe = partial(
+                flashinfer_cutlass_moe,
+                quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+                tp_rank=self.moe.moe_parallel_config.tp_rank,
+                tp_size=self.moe.moe_parallel_config.tp_size,
+                ep_rank=self.moe.moe_parallel_config.ep_rank,
+                ep_size=self.moe.moe_parallel_config.ep_size)
+        else:
+            if (self.moe.moe_parallel_config.use_ep
+                    and self.moe.moe_parallel_config.dp_size == 1):
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is available for EP"
+                    " but not enabled, consider setting"
+                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.")
+            elif self.moe.moe_parallel_config.dp_size > 1:
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is currently not available for DP."
+                )
+            self.flashinfer_cutlass_moe = None  # type: ignore
+
     def maybe_make_prepare_finalize(
             self) -> Optional[FusedMoEPrepareAndFinalize]:
         if self.rocm_aiter_moe_enabled:
@@ -367,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_pad = 256 // weight.element_size()
             weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
             torch.cuda.empty_cache()
+
         return weight
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -386,6 +422,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
 
+        if self.flashinfer_cutlass_moe_enabled:
+            # Swap halves to arrange as [w3; w1] (kernel expectation)
+            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+            layer.w13_weight.data = w13_weight_swapped.contiguous()
+
         if current_platform.is_xpu():
             import intel_extension_for_pytorch as ipex
             layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
@@ -536,6 +578,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
+        elif self.flashinfer_cutlass_moe_enabled:
+            return self.flashinfer_cutlass_moe(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input)
         elif self.fused_experts is not None:
             if self.moe.has_bias:
                 raise ValueError(
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 4ba14196682a5..b6afc8651e36d 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -598,6 +598,8 @@ class SharedResizableBuffer:
 
     def get(self, shape: tuple[int, ...], device: torch.device,
             dtype: torch.dtype):
+        if shape == () or shape is None:
+            return None
         shape_numel = prod(shape)
         if (self.buffer is None or self.buffer.numel() < shape_numel
                 or self.buffer.device != device or self.buffer.dtype != dtype):
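Illustrative usage sketch (not part of the patch above). Inside vLLM the new path is taken by UnquantizedFusedMoEMethod when VLLM_USE_FLASHINFER_MOE_FP16=1, expert parallelism is enabled with dp_size == 1, and the GPU is SM90 or newer; the helper added to flashinfer_cutlass_moe.py can also be called directly, roughly as below. The tensor shapes, the int32 topk_ids dtype, and the import location of FUSED_MOE_UNQUANTIZED_CONFIG are assumptions for illustration only; the flashinfer_cutlass_moe signature and the [w3; w1] weight layout come from the patch itself.

    # Sketch only: assumes FlashInfer with CUTLASS fused MoE is installed and
    # an SM90+ (Hopper/Blackwell) GPU is available. Import paths below are
    # assumptions except for flashinfer_cutlass_moe, which this patch adds.
    import torch

    from vllm.model_executor.layers.fused_moe.config import (
        FUSED_MOE_UNQUANTIZED_CONFIG)  # assumed location of this config
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
        flashinfer_cutlass_moe)

    num_experts, hidden, inter, tokens, top_k = 8, 4096, 14336, 16, 2
    dev, dtype = "cuda", torch.bfloat16

    x = torch.randn(tokens, hidden, device=dev, dtype=dtype)
    # w13 laid out as [w3; w1] along dim 1, matching what
    # process_weights_after_loading produces for this path in the patch.
    w13 = torch.randn(num_experts, 2 * inter, hidden, device=dev, dtype=dtype)
    w2 = torch.randn(num_experts, hidden, inter, device=dev, dtype=dtype)

    # Toy router: softmax over experts, keep the top_k per token.
    logits = torch.randn(tokens, num_experts, device=dev)
    topk_weights, topk_ids = torch.topk(
        torch.softmax(logits, dim=-1), top_k, dim=-1)

    out = flashinfer_cutlass_moe(
        hidden_states=x,
        w1=w13,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids.to(torch.int32),  # assumed int32 expert ids
        quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,  # unquantized BF16 path
    )
    print(out.shape)  # (tokens, hidden)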