[NVIDIA] Add SM100 Flashinfer MoE blockscale fp8 backend for low latency (#20645)

Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Kaixi Hou authored on 2025-07-19 02:33:01 -07:00; committed by GitHub
parent 7d94577138
commit 6d0734c562
6 changed files with 187 additions and 31 deletions
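
How to try it (illustrative sketch, not part of the diff): the FP8 path is opt-in via the new VLLM_USE_FLASHINFER_MOE_FP8 flag, while the old VLLM_USE_FLASHINFER_MOE flag is renamed to VLLM_USE_FLASHINFER_MOE_FP4. The model name and engine settings below are assumptions for illustration only; flashinfer must be installed and the GPU must be SM100-class.

    import os

    # Opt in to the new FlashInfer blockscale FP8 MoE backend added in this PR.
    os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"

    from vllm import LLM  # set the env var before constructing the engine

    # Hypothetical FP8 block-quantized, DeepSeek-style MoE checkpoint.
    llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
    print(llm.generate("Hello!")[0].outputs[0].text)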


@@ -119,7 +119,8 @@ if TYPE_CHECKING:
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
-    VLLM_USE_FLASHINFER_MOE: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -854,9 +855,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
+
+    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
-    "VLLM_USE_FLASHINFER_MOE":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+    "VLLM_USE_FLASHINFER_MOE_FP4":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),

     # Control the cache size used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.


@@ -191,7 +191,7 @@ class FusedMoEParallelConfig:
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE
+        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
                 and has_flashinfer_cutlass_fused_moe())

     @staticmethod


@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input)
+    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1061,6 +1061,104 @@ direct_register_custom_op(
 )

+def next_positive_power_of_2(x: int) -> int:
+    if x < 1:
+        return 1
+    return 1 << (x - 1).bit_length()

+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
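
[Quick check of the tile sizing above; the values are illustrative, not from the PR:]

    assert _get_tile_tokens_dim(num_tokens=4, top_k=8, num_experts=256) == 8      # 0 tokens/expert -> padded to 1 -> clamped up to 8
    assert _get_tile_tokens_dim(num_tokens=512, top_k=8, num_experts=256) == 16   # 16 is already a power of two
    assert _get_tile_tokens_dim(num_tokens=8192, top_k=8, num_experts=256) == 64  # 256 -> capped at the kernel maximum of 64
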
+def flashinfer_fused_moe_blockscale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
+    assert top_k <= global_num_experts
+    assert top_k <= 8
+    assert topk_group <= 4
+    assert global_num_experts > num_expert_group
+    assert global_num_experts % num_expert_group == 0
+    assert global_num_experts % 4 == 0
+    assert top_k < (topk_group * global_num_experts / num_expert_group)
+    assert block_shape == [128, 128]
+    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
+    # NOTE: scales of hidden states have to be transposed!
+    a_sf_t = a_sf.t().contiguous()
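    # [Note, not in the diff: with vllm's per_token_group_quant_fp8, a_sf has
    #  shape (num_tokens, K // 128), so the transpose above hands the kernel a
    #  (K // 128, num_tokens) scale layout.]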
+    return flashinfer_trtllm_fp8_block_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=a_q,
+        hidden_states_scale=a_sf_t,
+        gemm1_weights=w13_weight,
+        gemm1_weights_scale=w13_weight_scale_inv,
+        gemm2_weights=w2_weight,
+        gemm2_weights_scale=w2_weight_scale_inv,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling,
+        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                             global_num_experts),
+        routing_method_type=2,  # DeepSeek-styled routing method
+    )

+def flashinfer_fused_moe_blockscale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    return torch.empty_like(x)

+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_blockscale_fp8",
+    op_func=flashinfer_fused_moe_blockscale_fp8,
+    mutates_args=[],
+    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)

 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,


@@ -43,6 +43,7 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.flashinfer import has_flashinfer_moe

 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
@@ -52,6 +53,11 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)

+def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
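
# [Note, not in the diff: the reshape/flip swaps the two stacked halves of the
#  fused gate/up projection along dim -2, i.e. an (E, 2N, K) tensor laid out as
#  [w1; w3] becomes [w3; w1]; see the NOTE in the weight-processing hunk below
#  for why FlashInfer expects the swapped order.]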
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
@@ -473,6 +479,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
+            self.flashinfer_moe_enabled = True

         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -674,6 +685,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
+            elif self.flashinfer_moe_enabled:
+                # NOTE: weights have to be swapped since the activation is
+                # applied to a different half in flashinfer vs. vllm
+                w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
+                w13_weight_scale_inv = _swap_w13_to_w31(
+                    layer.w13_weight_scale_inv.data)
+                w2_weight = layer.w2_weight.data
+                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
             else:
                 w13_weight = layer.w13_weight.data
                 w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -915,25 +934,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)

-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-        )
+        if not self.flashinfer_moe_enabled:
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype,
+                enable_eplb=enable_eplb,
+                expert_map=expert_map,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )

         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
@@ -971,6 +990,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
+        elif self.flashinfer_moe_enabled:
+            # Currently this only works with DeepSeek-style models.
+            assert self.block_quant
+            assert (renormalize and use_grouped_topk
+                    and scoring_func == 'sigmoid'
+                    and custom_routing_function is None)
+            assert activation == "silu"
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                routing_logits=router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                x=x,
+                w13_weight=layer.w13_weight,
+                w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                w2_weight=layer.w2_weight,
+                w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                block_shape=self.quant_config.weight_block_size,
+                routed_scaling=1.0,
+            )
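# [Note, not in the diff: under expert parallelism each rank hands the kernel
#  only its local slice; e.g. 256 global experts across ep_size == 4 gives
#  local_num_experts == 64, so rank 2 passes expert_offset == 128.]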
         else:
             return self.fused_experts(
                 hidden_states=x,


@@ -721,7 +721,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.use_marlin = False
         self.allow_flashinfer_cutlass = False
-        if envs.VLLM_USE_FLASHINFER_MOE:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
             if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
                     and current_platform.is_device_capability(100):
                 logger.info_once(
@@ -800,10 +800,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             assert moe.dp_size > 1
             logger.debug_once("Using CutlassExpertsFp4")
             # Currently CutlassExpertsFp4 doesn't support DP
-            raise ValueError(
-                "CutlassExpertsFp4 doesn't support DP. "
-                "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
-                " backend instead.")
+            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
+                             "Use flashinfer CUTLASS FusedMoE backend instead "
+                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")

         return experts


@@ -64,6 +64,8 @@ def _lazy_import_wrapper(module_name: str,
 # Create lazy wrappers for each function
+flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")

 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
@@ -77,10 +79,16 @@ autotune = _lazy_import_wrapper(
     fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())

+@functools.cache
+def has_flashinfer_moe() -> bool:
+    """Return ``True`` if FlashInfer MoE module is available."""
+    return importlib.util.find_spec("flashinfer.fused_moe") is not None

 @functools.cache
 def has_flashinfer_cutlass_fused_moe() -> bool:
     """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
-    if not has_flashinfer():
+    if not has_flashinfer_moe():
         return False

     # Check if all required functions are available
@@ -99,9 +107,11 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
 __all__ = [
     "has_flashinfer",
-    "has_flashinfer_cutlass_fused_moe",
+    "flashinfer_trtllm_fp8_block_scale_moe",
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
     "fp4_swizzle_blockscale",
     "autotune",
+    "has_flashinfer_moe",
+    "has_flashinfer_cutlass_fused_moe",
 ]