diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 85ce77fb1f7f7..943695f921ad3 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -57,6 +57,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         tp_rank: int = 0,
         tp_size: int = 1,
         use_dp: bool = False,
+        use_deepseek_fp8_block_scale: bool = False,
     ):
         super().__init__(quant_config)
         assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -69,6 +70,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.tp_size = tp_size
         self.out_dtype = out_dtype
         self.use_dp = use_dp
+        # Enables DeepSeek-style FP8 block-scale path:
+        # - pass per-block weight scales to the kernel
+        # - skip input activation quantization (kernel applies scaling)
+        self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale

     @property
     def activation_formats(
@@ -147,7 +152,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "Only activation silu is supported in FlashInferExperts"
         )

-        if self.quant_dtype == torch.float8_e4m3fn:
+        # Select quantization metadata based on FP8 format/path
+        if (
+            self.quant_dtype == torch.float8_e4m3fn
+            and not self.use_deepseek_fp8_block_scale
+        ):
+            # FP8 per-tensor path: use global alphas/scales; do not pass input_sf
             quant_scales = [
                 self.g1_alphas,
                 self.a2_gscale,
@@ -176,6 +186,15 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights = w1.view(torch.long)
             fc2_expert_weights = w2.view(torch.long)
+        elif self.use_deepseek_fp8_block_scale:
+            # FP8 block-scale path: provide block-scale weights, omit a1q_scale
+            quant_scales = [
+                self.w1_scale,
+                self.w2_scale,
+            ]
+            a1q_scale = None
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2
         else:
             quant_scales = None
             a1q_scale = None
@@ -196,6 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             output=output,
+            # Informs FlashInfer to use the block-scale decoding path when True
+            use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
         )

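For review convenience, a small standalone sketch of the quant-metadata branch this hunk adds to FlashInferExperts.apply. The helper name select_fp8_quant_scales, the argument names, and the tensor shapes are illustrative only (not vLLM API); the nvfp4 branch, which the patch leaves untouched, is omitted.

# Illustrative mirror of the FP8 branch selection added above; not vLLM code.
import torch


def select_fp8_quant_scales(
    quant_dtype,
    use_deepseek_fp8_block_scale: bool,
    per_tensor_scales: list[torch.Tensor],
    w1_block_scale: torch.Tensor,
    w2_block_scale: torch.Tensor,
):
    if quant_dtype == torch.float8_e4m3fn and not use_deepseek_fp8_block_scale:
        # Per-tensor FP8: hand the global alphas/gscales to the kernel.
        return per_tensor_scales
    if use_deepseek_fp8_block_scale:
        # Block-scale FP8: only the per-block weight scales are passed through;
        # activation scaling is left to the kernel (a1q_scale stays None).
        return [w1_block_scale, w2_block_scale]
    return None


if __name__ == "__main__":
    w1_s, w2_s = torch.rand(8, 16), torch.rand(8, 16)  # shapes illustrative
    scales = select_fp8_quant_scales(torch.float8_e4m3fn, True, [], w1_s, w2_s)
    print(len(scales))  # -> 2: [w1_block_scale, w2_block_scale]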
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index bc9aab5208d9a..762890867e605 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -28,11 +28,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self,
         use_dp: bool,
         num_dispatchers: int = 1,
+        use_deepseek_fp8_block_scale: bool = False,
     ):
         super().__init__()
         self.num_dispatchers_ = num_dispatchers
         self.use_dp = use_dp
         self.local_tokens = None
+        # Toggle for DeepSeek-style FP8 block-scale path where activations are
+        # not quantized here and weight block scales are consumed by the kernel.
+        self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale

     @property
     def activation_format(self) -> mk.FusedMoEActivationFormat:
@@ -73,8 +77,9 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
         self,
         use_dp: bool,
         num_dispatchers: int = 1,
+        use_deepseek_fp8_block_scale: bool = False,
     ):
-        super().__init__(use_dp, num_dispatchers)
+        super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
         self.alltoall_info = None

         # Initialize all2all_manager only for DP case
@@ -97,15 +102,19 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
         )

         if not self.use_dp:
-            # Non-DP case: standard quantization
-            a1q, a1q_scale = moe_kernel_quantize_input(
-                a1,
-                quant_config.a1_gscale,
-                quant_config.quant_dtype,
-                quant_config.per_act_token_quant,
-                quant_config.block_shape,
-                is_fp4_scale_swizzled=not self.use_dp,
-            )
+            # Non-DP case: quantize activations unless using block-scale path
+            if not self.use_deepseek_fp8_block_scale:
+                a1q, a1q_scale = moe_kernel_quantize_input(
+                    a1,
+                    quant_config.a1_gscale,
+                    quant_config.quant_dtype,
+                    quant_config.per_act_token_quant,
+                    quant_config.block_shape,
+                    is_fp4_scale_swizzled=not self.use_dp,
+                )
+            else:
+                a1q = a1
+                a1q_scale = None
         else:
             # DP case: use FlashInfer AllToAll
             global_num_tokens_cpu = get_local_sizes()
@@ -122,6 +131,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
                     top_k,
                     num_experts,
                     quant_config,
+                    use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
                 )
             )

@@ -154,8 +164,9 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
         self,
         use_dp: bool,
         num_dispatchers: int = 1,
+        use_deepseek_fp8_block_scale: bool = False,
     ):
-        super().__init__(use_dp, num_dispatchers)
+        super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)

     def prepare(
         self,
@@ -173,22 +184,42 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
         if not self.use_dp and quant_config.quant_dtype == "nvfp4":
             return a1, None, None, topk_ids, topk_weights

-        a1q, a1q_scale = moe_kernel_quantize_input(
-            a1,
-            quant_config.a1_gscale,
-            quant_config.quant_dtype,
-            quant_config.per_act_token_quant,
-            quant_config.block_shape,
-            is_fp4_scale_swizzled=not self.use_dp,
-        )
+        if not self.use_deepseek_fp8_block_scale:
+            a1q, a1q_scale = moe_kernel_quantize_input(
+                a1,
+                quant_config.a1_gscale,
+                quant_config.quant_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
+                is_fp4_scale_swizzled=not self.use_dp,
+            )
+        else:
+            # Block-scale path: pass activations through, omit per-token scales
+            a1q = a1
+            a1q_scale = None

         if self.use_dp:
-            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-                [topk_weights, topk_ids, a1q, a1q_scale],
-                dim=0,
-                sizes=get_local_sizes(),
-            )
-            if quant_config.quant_dtype == "nvfp4":
+            # Build gather list conditionally - omit a1q_scale if None
+            # (block-scale path)
+            gather_list = [topk_weights, topk_ids, a1q]
+            if a1q_scale is not None:
+                gather_list.append(a1q_scale)
+                gathered = get_dp_group().all_gatherv(
+                    gather_list,
+                    dim=0,
+                    sizes=get_local_sizes(),
+                )
+                topk_weights, topk_ids, a1q, a1q_scale = gathered
+            else:
+                gathered = get_dp_group().all_gatherv(
+                    gather_list,
+                    dim=0,
+                    sizes=get_local_sizes(),
+                )
+                topk_weights, topk_ids, a1q = gathered
+                a1q_scale = None
+
+            if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
                 a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

         return a1q, a1q_scale, None, topk_ids, topk_weights
@@ -221,6 +252,7 @@ def flashinfer_alltoall_dispatch(
     top_k: int,
     num_experts: int,
     quant_config: FusedMoEQuantConfig,
+    use_deepseek_fp8_block_scale: bool = False,
 ):
     from flashinfer.comm.trtllm_alltoall import MnnvlMoe

@@ -250,30 +282,42 @@ def flashinfer_alltoall_dispatch(
     )
     topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)

-    x, x_sf = moe_kernel_quantize_input(
-        x,
-        gs,
-        quant_config.quant_dtype,
-        quant_config.per_act_token_quant,
-        quant_config.block_shape,
-        is_fp4_scale_swizzled=False,  # delay swizzle to after comm
-    )
-    x = MnnvlMoe.mnnvl_moe_alltoallv(
-        x,
-        alltoall_info,
-        all2all_manager.workspace_tensor,
-        ep_rank,
-        ep_size,
-    )
+    if not use_deepseek_fp8_block_scale:
+        x, x_sf = moe_kernel_quantize_input(
+            x,
+            gs,
+            quant_config.quant_dtype,
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+            is_fp4_scale_swizzled=False,  # delay swizzle to after comm
+        )
+        x = MnnvlMoe.mnnvl_moe_alltoallv(
+            x,
+            alltoall_info,
+            all2all_manager.workspace_tensor,
+            ep_rank,
+            ep_size,
+        )

-    x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
-        x_sf,
-        alltoall_info,
-        all2all_manager.workspace_tensor,
-        ep_rank,
-        ep_size,
-    )
-    x_sf = nvfp4_block_scale_interleave(x_sf)
+        x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
+            x_sf,
+            alltoall_info,
+            all2all_manager.workspace_tensor,
+            ep_rank,
+            ep_size,
+        )
+        if quant_config.quant_dtype == "nvfp4":
+            x_sf = nvfp4_block_scale_interleave(x_sf)
+    else:
+        # Block-scale path: pass activations through without quantization
+        x_sf = None
+        x = MnnvlMoe.mnnvl_moe_alltoallv(
+            x,
+            alltoall_info,
+            all2all_manager.workspace_tensor,
+            ep_rank,
+            ep_size,
+        )

     return alltoall_info, topk_ids, topk_weights, x, x_sf
@@ -304,6 +348,7 @@ def create_flashinfer_prepare_finalize(
     use_dp: bool,
     use_nvfp4: bool = False,
     enable_alltoallv: bool = False,
+    use_deepseek_fp8_block_scale: bool = False,
 ) -> FlashInferCutlassMoEPrepareAndFinalize:
     """Factory function to create the appropriate FlashInfer implementation."""
     if use_nvfp4:
@@ -311,5 +356,7 @@ def create_flashinfer_prepare_finalize(
             return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
         else:
             return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
-    # Fp8 only supports AllGather
-    return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
+    # FP8 path currently supported via AllGather; optionally enable block-scale
+    return FlashInferAllGatherMoEPrepareAndFinalize(
+        use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
+    )
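The DP branch above builds the all-gather list conditionally so that the block-scale path, which carries no per-token activation scales, never tries to gather a None entry. A minimal standalone illustration of that pattern follows; fake_all_gatherv and gather_activations are stand-ins invented here (the real code uses get_dp_group().all_gatherv), simulating a 2-rank world where every rank holds identical data.

# Standalone illustration of the conditional gather pattern added above; not vLLM code.
import torch


def fake_all_gatherv(tensors: list[torch.Tensor], dim: int = 0) -> list[torch.Tensor]:
    # Stand-in for get_dp_group().all_gatherv: concatenate each tensor with itself.
    return [torch.cat([t, t], dim=dim) for t in tensors]


def gather_activations(topk_weights, topk_ids, a1q, a1q_scale):
    gather_list = [topk_weights, topk_ids, a1q]
    if a1q_scale is not None:
        gather_list.append(a1q_scale)
        topk_weights, topk_ids, a1q, a1q_scale = fake_all_gatherv(gather_list)
    else:
        # Block-scale path: no per-token activation scales to gather.
        topk_weights, topk_ids, a1q = fake_all_gatherv(gather_list)
    return topk_weights, topk_ids, a1q, a1q_scale


if __name__ == "__main__":
    w = torch.rand(4, 8)
    ids = torch.randint(0, 16, (4, 8))
    x = torch.rand(4, 128)
    out = gather_activations(w, ids, x, None)
    print([t.shape for t in out[:3]], out[3])  # gathered shapes; scale stays None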
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index bbd0a4df1048b..0479bec338408 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -3,6 +3,7 @@

 from collections.abc import Callable
 from enum import Enum
+from functools import partial
 from typing import TYPE_CHECKING, Any, Optional

 import torch
@@ -122,10 +123,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
-    # prefer FlashInfer backends when available and enabled on supported GPUs
+    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
     if (
         current_platform.is_cuda()
-        and current_platform.is_device_capability(100)
+        and (
+            current_platform.is_device_capability(100)
+            or current_platform.is_device_capability(90)
+        )
         and envs.VLLM_USE_FLASHINFER_MOE_FP8
         and has_flashinfer_moe()
     ):
@@ -134,14 +138,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
             logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
-            if block_quant:
+            if block_quant and current_platform.is_device_capability(100):
                 raise ValueError(
                     "FlashInfer FP8 MoE throughput backend does not "
                     "support block quantization. Please use "
                     "VLLM_FLASHINFER_MOE_BACKEND=latency "
                     "instead."
                 )
-            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
             return Fp8MoeBackend.FLASHINFER_CUTLASS

     # weight-only path for older GPUs without native FP8
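With the change above, the throughput (CUTLASS) branch only rejects block quantization on SM100; on SM90 it now proceeds, since the block-scale support added elsewhere in this patch covers that case. A toy mirror of just that decision, with an illustrative function name and parameters (not vLLM API):

# Toy mirror of the throughput-branch decision changed above; illustrative only.
def cutlass_branch_allows(block_quant: bool, is_sm100: bool) -> str:
    if block_quant and is_sm100:
        raise ValueError(
            "FlashInfer FP8 MoE throughput backend does not support block "
            "quantization here; use VLLM_FLASHINFER_MOE_BACKEND=latency instead."
        )
    return "FLASHINFER_CUTLASS"


if __name__ == "__main__":
    print(cutlass_branch_allows(block_quant=True, is_sm100=False))  # SM90: allowed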
@@ -641,6 +645,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
         elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
             self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+            if self.block_quant:
+                assert self.weight_block_size == [128, 128], (
+                    f"Only support weight_block_size == [128, 128], "
+                    f"got {self.weight_block_size}"
+                )
+            self.flashinfer_moe_fn = partial(
+                flashinfer_cutlass_moe_fp8,
+                moe=self.moe,
+                use_deepseek_fp8_block_scale=self.block_quant,
+            )

         self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
         self.allow_cutlass_block_scaled_grouped_gemm = (
@@ -1012,8 +1026,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ):
             return None
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+            if self.block_quant:
+                assert self.weight_block_size == [128, 128], (
+                    f"Only support weight_block_size == [128, 128], "
+                    f"got {self.weight_block_size}"
+                )
+            # Wire block-scale flag through prepare/finalize when using CUTLASS
             prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
-                self.moe
+                self.moe,
+                use_deepseek_fp8_block_scale=self.block_quant,
             )
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
@@ -1062,9 +1083,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )

         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+            # Select GEMM experts with block-scale when weights are block-quantized
             experts = select_cutlass_fp8_gemm_impl(
                 self.moe,
                 self.moe_quant_config,
+                use_deepseek_fp8_block_scale=self.block_quant,
             )
             logger.debug_once("Using %s", experts.__class__.__name__)
             return experts
@@ -1251,16 +1274,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 workspace=layer.workspace,
             )
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert not self.block_quant
-            assert not renormalize and custom_routing_function is not None
             assert activation == "silu", (
                 f"Expected 'silu' activation but got {activation}"
             )
-            assert scoring_func == "sigmoid", (
-                f"Expected 'sigmoid' scoring func but got {scoring_func}"
-            )
-
-            result = flashinfer_cutlass_moe_fp8(
+            if not self.block_quant:
+                assert not renormalize and custom_routing_function is not None
+                assert scoring_func == "sigmoid", (
+                    f"Expected 'sigmoid' scoring func but got {scoring_func}"
+                )
+            # Delegate to the FlashInfer CUTLASS path; the partial already binds
+            # use_deepseek_fp8_block_scale when block quantization is enabled
+            result = self.flashinfer_moe_fn(
                 x,
                 layer,
                 topk_weights,
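The constructor hunk above pre-binds flashinfer_cutlass_moe_fp8 with functools.partial so that apply() can call it uniformly for both the per-tensor and block-scale cases. A self-contained sketch of that binding pattern; fake_moe_fp8 and the literal "moe-config" are stand-ins, not the real kernel wrapper or config object.

# Minimal illustration of the functools.partial binding used above; the callee
# here is a stand-in, not vLLM's flashinfer_cutlass_moe_fp8.
from functools import partial


def fake_moe_fp8(x, layer, *, moe=None, use_deepseek_fp8_block_scale=False):
    return f"moe={moe!r}, block_scale={use_deepseek_fp8_block_scale}"


block_quant = True  # would be self.block_quant in Fp8MoEMethod
flashinfer_moe_fn = partial(
    fake_moe_fp8, moe="moe-config", use_deepseek_fp8_block_scale=block_quant
)

# Call sites pass only the runtime arguments; the config flags are already bound.
print(flashinfer_moe_fn("hidden_states", "layer"))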
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index e49d374f154d8..d9e9b42402712 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
     create_flashinfer_prepare_finalize,
 )
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:


 def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
-    moe: FusedMoEConfig | None,
+    moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
 ) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
-    return create_flashinfer_prepare_finalize(use_dp)
+    # Propagate the block-scale flag so prepare/finalize can skip activation
+    # quantization and the kernel consumes per-block weight scales instead.
+    return create_flashinfer_prepare_finalize(
+        use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
+    )


 def select_cutlass_fp8_gemm_impl(
     moe: FusedMoEConfig | None,
     quant_config: FusedMoEQuantConfig,
     out_dtype: torch.dtype | None = None,
+    use_deepseek_fp8_block_scale: bool = False,
 ) -> mk.FusedMoEPermuteExpertsUnpermute:
     """Return a GEMM *experts* implementation for fused-MoE layers"""
@@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
             ep_size=moe.moe_parallel_config.ep_size,
             tp_rank=moe.moe_parallel_config.tp_rank,
             tp_size=moe.moe_parallel_config.tp_size,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         )

     assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
     return FlashInferExperts(
         out_dtype=out_dtype,
         quant_config=quant_config,
+        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
     )

@@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
+    use_deepseek_fp8_block_scale: bool = False,
+    moe: FusedMoEConfig | None = None,
 ) -> torch.Tensor:
     quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
     assert quant_config is not None

+    # Construct the modular kernel with block-scale support when requested.
     fused_experts = mk.FusedMoEModularKernel(
-        build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
+        build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+            moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
+        ),
         select_cutlass_fp8_gemm_impl(
-            moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
+            moe=moe,
+            quant_config=quant_config,
+            out_dtype=hidden_states.dtype,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
     )

@@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8(

 def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
     flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-    if flashinfer_moe_backend == "throughput":
+    # SM90 only supports the CUTLASS backend; the TRTLLM (latency) one targets SM100
+    if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
+        90
+    ):
         return FlashinferMoeBackend.CUTLASS
     elif flashinfer_moe_backend == "latency":
         return FlashinferMoeBackend.TENSORRT_LLM
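Finally, a hedged usage sketch of how this path would be exercised end to end. The environment variable names come from the diff itself; the model name is hypothetical, and running it assumes an SM90/SM100 GPU with FlashInfer installed.

# Hedged usage sketch: enable the FlashInfer CUTLASS FP8 MoE path that this
# patch extends to SM90 and to DeepSeek-style block-quantized weights.
import os

os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"           # opt in to FlashInfer MoE
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"  # maps to CUTLASS

from vllm import LLM  # noqa: E402

# Hypothetical checkpoint: any FP8 MoE model whose weights are block-quantized
# with weight_block_size == [128, 128] should hit the new code path.
llm = LLM(model="some-org/some-fp8-block-quant-moe-model")
print(llm.generate("Hello")[0].outputs[0].text)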