[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Duncan Moss 2025-11-14 08:02:44 -08:00 committed by GitHub
parent 511a6b611d
commit 3f8a874065
4 changed files with 179 additions and 68 deletions

@@ -57,6 +57,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
tp_rank: int = 0,
tp_size: int = 1,
use_dp: bool = False,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -69,6 +70,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.tp_size = tp_size
self.out_dtype = out_dtype
self.use_dp = use_dp
# Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
@property
def activation_formats(
@@ -147,7 +152,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"Only activation silu is supported in FlashInferExperts"
)
if self.quant_dtype == torch.float8_e4m3fn:
# Select quantization metadata based on FP8 format/path
if (
self.quant_dtype == torch.float8_e4m3fn
and not self.use_deepseek_fp8_block_scale
):
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
quant_scales = [
self.g1_alphas,
self.a2_gscale,
@@ -176,6 +186,15 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
elif self.use_deepseek_fp8_block_scale:
# FP8 block-scale path: provide block-scale weights, omit a1q_scale
quant_scales = [
self.w1_scale,
self.w2_scale,
]
a1q_scale = None
fc1_expert_weights = w1
fc2_expert_weights = w2
else:
quant_scales = None
a1q_scale = None
@@ -196,6 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
output=output,
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
)
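
The new branches above decide which scale tensors reach the FlashInfer kernel. A minimal standalone sketch of that selection, with the string "fp8" standing in for torch.float8_e4m3fn, strings standing in for tensors, and the nvfp4 branch omitted; select_quant_metadata is an illustrative name, not part of the diff:

# Illustrative only: mirrors the scale-selection branches in FlashInferExperts.apply.
def select_quant_metadata(
    quant_dtype: str | None,
    use_deepseek_fp8_block_scale: bool,
    a1q_scale: str | None,
) -> tuple[list[str] | None, str | None]:
    """Return (quant_scales, a1q_scale) the way the branches above do."""
    if quant_dtype == "fp8" and not use_deepseek_fp8_block_scale:
        # Per-tensor FP8: global alphas / gscales, keep the activation scale.
        return ["g1_alphas", "a2_gscale"], a1q_scale
    if use_deepseek_fp8_block_scale:
        # Block-scale path: only per-block weight scales, no activation scale.
        return ["w1_scale", "w2_scale"], None
    # nvfp4 and unquantized branches omitted for brevity.
    return None, None

print(select_quant_metadata("fp8", False, "a1q_scale"))  # (['g1_alphas', 'a2_gscale'], 'a1q_scale')
print(select_quant_metadata("fp8", True, "a1q_scale"))   # (['w1_scale', 'w2_scale'], None)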

@@ -28,11 +28,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self,
use_dp: bool,
num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.local_tokens = None
# Toggle for DeepSeek-style FP8 block-scale path where activations are
# not quantized here and weight block scales are consumed by the kernel.
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
@@ -73,8 +77,9 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
self,
use_dp: bool,
num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(use_dp, num_dispatchers)
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
self.alltoall_info = None
# Initialize all2all_manager only for DP case
@@ -97,15 +102,19 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
)
if not self.use_dp:
# Non-DP case: standard quantization
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
# Non-DP case: quantize activations unless using block-scale path
if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
else:
a1q = a1
a1q_scale = None
else:
# DP case: use FlashInfer AllToAll
global_num_tokens_cpu = get_local_sizes()
@@ -122,6 +131,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
top_k,
num_experts,
quant_config,
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
)
)
@@ -154,8 +164,9 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self,
use_dp: bool,
num_dispatchers: int = 1,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(use_dp, num_dispatchers)
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
def prepare(
self,
@@ -173,22 +184,42 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
return a1, None, None, topk_ids, topk_weights
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
else:
# Block-scale path: pass activations through, omit per-token scales
a1q = a1
a1q_scale = None
if self.use_dp:
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes(),
)
if quant_config.quant_dtype == "nvfp4":
# Build gather list conditionally - omit a1q_scale if None
# (block-scale path)
gather_list = [topk_weights, topk_ids, a1q]
if a1q_scale is not None:
gather_list.append(a1q_scale)
gathered = get_dp_group().all_gatherv(
gather_list,
dim=0,
sizes=get_local_sizes(),
)
topk_weights, topk_ids, a1q, a1q_scale = gathered
else:
gathered = get_dp_group().all_gatherv(
gather_list,
dim=0,
sizes=get_local_sizes(),
)
topk_weights, topk_ids, a1q = gathered
a1q_scale = None
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
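
The DP branch above builds the all-gather list dynamically because the block-scale path never produces an activation scale, so it gathers one fewer tensor and must unpack accordingly. A self-contained sketch of that pairing, with a toy all_gatherv that simply echoes its inputs in place of get_dp_group().all_gatherv and strings standing in for tensors:

# Illustrative only: the gather list and the unpack must stay in sync.
def all_gatherv(tensors: list) -> list:
    # Toy stand-in for get_dp_group().all_gatherv: returns inputs unchanged.
    return tensors

def gather_for_dp(topk_weights, topk_ids, a1q, a1q_scale):
    gather_list = [topk_weights, topk_ids, a1q]
    if a1q_scale is not None:
        gather_list.append(a1q_scale)
        topk_weights, topk_ids, a1q, a1q_scale = all_gatherv(gather_list)
    else:
        # Block-scale path: a1q_scale was never produced, so gather 3 tensors.
        topk_weights, topk_ids, a1q = all_gatherv(gather_list)
        a1q_scale = None
    return topk_weights, topk_ids, a1q, a1q_scale

print(gather_for_dp("w", "ids", "a1q", "scale"))  # 4-tensor gather
print(gather_for_dp("w", "ids", "a1q", None))     # 3-tensor gather, scale stays None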
@@ -221,6 +252,7 @@ def flashinfer_alltoall_dispatch(
top_k: int,
num_experts: int,
quant_config: FusedMoEQuantConfig,
use_deepseek_fp8_block_scale: bool = False,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
@@ -250,30 +282,42 @@ def flashinfer_alltoall_dispatch(
)
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
x, x_sf = moe_kernel_quantize_input(
x,
gs,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
if not use_deepseek_fp8_block_scale:
x, x_sf = moe_kernel_quantize_input(
x,
gs,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = nvfp4_block_scale_interleave(x_sf)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
if quant_config.quant_dtype == "nvfp4":
x_sf = nvfp4_block_scale_interleave(x_sf)
else:
# Block-scale path: pass activations through without quantization
x_sf = None
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
return alltoall_info, topk_ids, topk_weights, x, x_sf
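
On the block-scale path the dispatch skips quantization and ships only the raw activations over the all-to-all; otherwise it quantizes first, exchanges both the payload and its scale factors, and swizzles the scales afterwards for nvfp4. A compact sketch of that ordering, with a pass-through alltoallv standing in for MnnvlMoe.mnnvl_moe_alltoallv and strings standing in for tensors:

# Illustrative only: ordering of quantize / all-to-all / swizzle per path.
def alltoallv(t):
    return t  # stand-in for MnnvlMoe.mnnvl_moe_alltoallv

def dispatch(x, quant_dtype: str | None, use_deepseek_fp8_block_scale: bool):
    if not use_deepseek_fp8_block_scale:
        x, x_sf = f"quant({x})", f"sf({x})"   # moe_kernel_quantize_input stand-in
        x = alltoallv(x)                      # ship quantized activations
        x_sf = alltoallv(x_sf)                # ship their scale factors
        if quant_dtype == "nvfp4":
            x_sf = f"swizzled({x_sf})"        # nvfp4_block_scale_interleave stand-in
    else:
        x_sf = None                           # kernel consumes weight block scales instead
        x = alltoallv(x)                      # raw activations only
    return x, x_sf

print(dispatch("act", "nvfp4", False))
print(dispatch("act", "fp8", True))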
@@ -304,6 +348,7 @@ def create_flashinfer_prepare_finalize(
use_dp: bool,
use_nvfp4: bool = False,
enable_alltoallv: bool = False,
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize:
"""Factory function to create the appropriate FlashInfer implementation."""
if use_nvfp4:
@@ -311,5 +356,7 @@ def create_flashinfer_prepare_finalize(
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# Fp8 only supports AllGather
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# FP8 currently supports only the AllGather path; the block-scale flag is threaded through when requested
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
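
After this change the factory has three outcomes. A pure-Python mirror of the dispatch, returning class names as strings so it runs standalone (the signature matches the factory above; the body is illustrative):

# Illustrative only: which prepare/finalize class the factory picks.
def create_flashinfer_prepare_finalize_sketch(
    use_dp: bool,
    use_nvfp4: bool = False,
    enable_alltoallv: bool = False,
    use_deepseek_fp8_block_scale: bool = False,
) -> str:
    if use_nvfp4:
        # nvfp4 may use either communication pattern.
        return ("FlashInferAllToAllMoEPrepareAndFinalize"
                if enable_alltoallv
                else "FlashInferAllGatherMoEPrepareAndFinalize")
    # FP8 only supports AllGather; the block-scale flag is threaded through.
    return (f"FlashInferAllGatherMoEPrepareAndFinalize("
            f"use_dp={use_dp}, "
            f"use_deepseek_fp8_block_scale={use_deepseek_fp8_block_scale})")

print(create_flashinfer_prepare_finalize_sketch(use_dp=True, use_nvfp4=True, enable_alltoallv=True))
print(create_flashinfer_prepare_finalize_sketch(use_dp=False, use_deepseek_fp8_block_scale=True))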

@@ -3,6 +3,7 @@
from collections.abc import Callable
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -122,10 +123,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# prefer FlashInfer backends when available and enabled on supported GPUs
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
@@ -134,14 +138,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
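
The selection above now keys on device capability as well as the env vars. A runnable sketch of the resulting decision tree; capability and env values are plain parameters, has_flashinfer replaces the real has_flashinfer_moe() probe, and the condition guarding the TRTLLM branch sits outside the hunk, so it is approximated here:

# Illustrative only: how SM90 reaches the CUTLASS backend after this change.
def pick_fp8_moe_backend(
    capability: int,           # 90 for Hopper, 100 for Blackwell
    use_flashinfer_fp8: bool,  # VLLM_USE_FLASHINFER_MOE_FP8
    flashinfer_backend: str,   # VLLM_FLASHINFER_MOE_BACKEND: "latency" | "throughput"
    has_flashinfer: bool,
    block_quant: bool,
) -> str:
    if capability in (90, 100) and use_flashinfer_fp8 and has_flashinfer:
        if flashinfer_backend == "latency" and capability == 100:
            return "FLASHINFER_TRTLLM"       # SM100-only latency backend
        if block_quant and capability == 100:
            raise ValueError("throughput backend has no block quant on SM100; "
                             "set VLLM_FLASHINFER_MOE_BACKEND=latency")
        return "FLASHINFER_CUTLASS"          # SM90 and SM100 throughput path
    return "OTHER"                           # DeepGEMM / grouped-GEMM / weight-only fallbacks

print(pick_fp8_moe_backend(90, True, "throughput", True, block_quant=True))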
@@ -641,6 +645,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
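
Binding moe and use_deepseek_fp8_block_scale with functools.partial at init time means the apply-time call site further down only passes per-call tensors. A minimal illustration of that pattern; cutlass_moe_fp8 here is a hypothetical stand-in, not the real flashinfer_cutlass_moe_fp8:

from functools import partial

# Illustrative only: the pre-binding pattern used for self.flashinfer_moe_fn.
def cutlass_moe_fp8(x, layer, topk_weights, *, moe=None,
                    use_deepseek_fp8_block_scale=False):
    mode = "block-scale" if use_deepseek_fp8_block_scale else "per-tensor"
    return f"run {mode} MoE on {x} (moe config: {moe})"

# At init time: fix the configuration arguments once.
moe_fn = partial(cutlass_moe_fp8, moe="moe_cfg", use_deepseek_fp8_block_scale=True)

# At apply time: only per-call inputs are passed, as in the diff's call site.
print(moe_fn("hidden_states", "layer", "topk_weights"))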
@@ -1012,8 +1026,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
# Wire block-scale flag through prepare/finalize when using CUTLASS
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
@@ -1062,9 +1083,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
self.moe_quant_config,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -1251,16 +1274,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not self.block_quant
assert not renormalize and custom_routing_function is not None
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
result = flashinfer_cutlass_moe_fp8(
if not self.block_quant:
assert not renormalize and custom_routing_function is not None
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
# Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable
result = self.flashinfer_moe_fn(
x,
layer,
topk_weights,

@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig | None,
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
return create_flashinfer_prepare_finalize(use_dp)
# Propagate block-scale flag so prepare/finalize can skip act quantization
# and inform the kernel to consume per-block weight scales.
return create_flashinfer_prepare_finalize(
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
def select_cutlass_fp8_gemm_impl(
moe: FusedMoEConfig | None,
quant_config: FusedMoEQuantConfig,
out_dtype: torch.dtype | None = None,
use_deepseek_fp8_block_scale: bool = False,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
@@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
return FlashInferExperts(
out_dtype=out_dtype,
quant_config=quant_config,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
@@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_deepseek_fp8_block_scale: bool = False,
moe: FusedMoEConfig | None = None,
) -> torch.Tensor:
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
assert quant_config is not None
# Construct modular kernel with block-scale support when requested.
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
),
select_cutlass_fp8_gemm_impl(
moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
moe=moe,
quant_config=quant_config,
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
)
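
The modular kernel composes two halves that must agree on the block-scale flag: prepare/finalize decides whether to quantize activations, and the experts implementation decides which scales to hand the kernel. A toy composition showing why the flag is passed to both; the class and method names are simplified stand-ins, not the real mk.FusedMoEModularKernel API:

# Illustrative only: why the same flag is passed to both halves of the kernel.
class PrepareFinalize:
    def __init__(self, use_deepseek_fp8_block_scale: bool):
        self.block_scale = use_deepseek_fp8_block_scale

    def prepare(self, a1):
        # Block-scale path skips activation quantization entirely.
        return (a1, None) if self.block_scale else (f"quant({a1})", "a1q_scale")

class Experts:
    def __init__(self, use_deepseek_fp8_block_scale: bool):
        self.block_scale = use_deepseek_fp8_block_scale

    def apply(self, a1q, a1q_scale):
        scales = ["w1_scale", "w2_scale"] if self.block_scale else ["g1_alphas", "a2_gscale"]
        return f"gemm({a1q}, a1q_scale={a1q_scale}, quant_scales={scales})"

def fused_moe(a1, block_scale: bool) -> str:
    prep, experts = PrepareFinalize(block_scale), Experts(block_scale)
    a1q, a1q_scale = prep.prepare(a1)
    return experts.apply(a1q, a1q_scale)

print(fused_moe("x", block_scale=False))
print(fused_moe("x", block_scale=True))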
@@ -258,7 +274,10 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend == "throughput":
# TRTLLM FP8 MoE is SM100-only, so SM90 always falls back to CUTLASS regardless of the env setting
if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
90
):
return FlashinferMoeBackend.CUTLASS
elif flashinfer_moe_backend == "latency":
return FlashinferMoeBackend.TENSORRT_LLM
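
With the SM90 override, the env var only matters on SM100. A two-branch mirror of the new condition; capability is passed in explicitly since the real code queries current_platform, and handling of other env values lies outside the hunk and is omitted:

# Illustrative only: SM90 overrides VLLM_FLASHINFER_MOE_BACKEND.
def pick_flashinfer_moe_backend(env_value: str, capability: int) -> str:
    if env_value == "throughput" or capability == 90:
        return "CUTLASS"        # forced on SM90, default for "throughput"
    if env_value == "latency":
        return "TENSORRT_LLM"   # still available on SM100
    return "UNSPECIFIED"        # other values handled further down in the real function

print(pick_flashinfer_moe_backend("latency", 90))   # CUTLASS
print(pick_flashinfer_moe_backend("latency", 100))  # TENSORRT_LLM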