Support TensorRT-LLM MoE FP4 for low-latency inference (#21331)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: XIn Li <xinli@nvidia.com>
Co-authored-by: XIn Li <xinli@nvidia.com>
Authored by Shu Wang on 2025-08-07 21:18:22 -05:00; committed by GitHub
parent d57dc2364e
commit a3b9c17b56
7 changed files with 288 additions and 43 deletions


@@ -129,6 +129,7 @@ if TYPE_CHECKING:
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# FlashInfer MoE backend for vLLM's fused Mixture-of-Experts support.
# Both backends require compute capability 10.0 or above.
# Available options:
# - "throughput": [default]
# Uses CUTLASS kernels optimized for high-throughput batch inference.
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
# To set this backend, define the environment variable:
# export VLLM_FLASHINFER_MOE_BACKEND=latency
# If not set, defaults to "throughput".
"VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
"VLLM_FLASHINFER_MOE_BACKEND", "throughput"
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
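
For reference, a minimal usage sketch of the new knob, assuming an NVFP4-quantized MoE checkpoint and a compute-capability 10.0+ GPU (the model name is illustrative; checkpoint quantization is normally auto-detected):

```python
# Minimal usage sketch: select the low-latency TensorRT-LLM FP4 MoE path.
# Set both variables before constructing the engine.
import os

os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"        # enable FlashInfer FP4 MoE
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"  # TensorRT-LLM kernels

from vllm import LLM

llm = LLM(model="nvidia/DeepSeek-R1-FP4")  # illustrative NVFP4 MoE checkpoint
print(llm.generate("Hello")[0].outputs[0].text)
```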


@@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe())
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(tp_size_: int, dp_size_: int,
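
The effect of the extra condition is that the CUTLASS fused-MoE path is only taken when the backend stays at "throughput"; a trivial stand-alone restatement (helper name is illustrative, not part of the diff):

```python
# Stand-in restatement of the property above: "latency" bypasses the
# FlashInfer CUTLASS kernels and routes to the TensorRT-LLM path instead.
def uses_flashinfer_cutlass(fp4_enabled: bool, kernels_available: bool,
                            backend: str) -> bool:
    return fp4_enabled and kernels_available and backend == "throughput"

assert uses_flashinfer_cutlass(True, True, "throughput")
assert not uses_flashinfer_cutlass(True, True, "latency")
```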


@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.fused_experts = None # type: ignore[assignment]
@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
requires_grad=False)
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
if self.allow_flashinfer_cutlass:
if self.allow_flashinfer:
w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
layer.w13_weight_scale.data,
dim=-2)
@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config):
if not self.allow_flashinfer_cutlass:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
logger)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def apply(
self,


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Any, Callable, Optional, Union
import torch
@@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import has_flashinfer_moe
logger = init_logger(__name__)
@@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
class FlashinferMoeBackend(Enum):
TENSORRT_LLM = "TensorRT-LLM"
CUTLASS = "CUTLASS"
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
@@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
@@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
@@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: list):
def is_layer_excluded(self, prefix: str,
exclude_modules: list[str]) -> bool:
import regex as re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
@@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
@@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
return out.view(*output_shape)
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
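
A worked example of this heuristic, with a stand-in next_power_of_2 so the snippet is self-contained (vLLM imports its own from vllm.utils):

```python
# Stand-in for vllm.utils.next_power_of_2, for illustration only.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    num_tokens_per_expert = (num_tokens * top_k) // num_experts  # assume even routing
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)  # kernel supports 8-64

# 128 tokens, top_k=8, 256 experts: 4 tokens/expert -> pow2 4 -> clamped up to 8.
assert tile_tokens_dim(128, 8, 256) == 8
# 4096 tokens, top_k=8, 256 experts: 128 tokens/expert -> clamped down to 64.
assert tile_tokens_dim(4096, 8, 256) == 64
```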
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
@@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.flashinfer_moe_backend = None
self.fused_experts = None # type: ignore
if self.allow_flashinfer:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend == "throughput":
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
logger.info_once("Using FlashInfer CUTLASS kernels for "
"ModelOptNvFp4FusedMoE.")
elif flashinfer_moe_backend == "latency":
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
"ModelOptNvFp4FusedMoE.")
else:
allowed_backends = ["throughput", "latency"]
raise ValueError(
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
f" expected one of {allowed_backends}")
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
if not self.allow_flashinfer_cutlass:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
@@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
logger)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_loader=weight_loader)
layer.register_parameter("w2_input_scale", w2_input_scale)
def prepare_static_weight_layouts_for_trtllm_moe(
self,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm1_scales_linear_fp4_bytes: torch.Tensor,
gemm2_scales_linear_fp4_bytes: torch.Tensor,
hidden_size: int,
intermediate_size: int,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Prepare quantized weights for kernel (done offline with weights)."""
from flashinfer import (reorder_rows_for_gated_act_gemm,
shuffle_matrix_a, shuffle_matrix_sf_a)
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
hidden_size //
16) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn).reshape(num_experts, hidden_size,
intermediate_size //
16) # fp8 scaling factors
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(
gemm1_scales_linear_fp4[i].clone()))
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved).reshape(num_experts,
2 * intermediate_size,
hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved).reshape(num_experts,
2 * intermediate_size,
hidden_size // 16)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8),
epilogue_tile_m))
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8),
epilogue_tile_m))
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m))
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m))
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled).view(
torch.float8_e4m3fn).reshape(num_experts,
2 * intermediate_size,
hidden_size // 16))
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled).view(
torch.float8_e4m3fn).reshape(num_experts, hidden_size,
intermediate_size // 16))
return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
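
As a quick sanity check on the reshapes above (packed FP4 stores two values per byte along the last dimension; block scales carry one FP8 value per 16 elements), these are the expected shapes for an illustrative configuration not taken from the diff:

```python
# Illustrative config: 8 experts, hidden_size=4096, intermediate_size=2048.
num_experts, hidden_size, intermediate_size = 8, 4096, 2048

w13_fp4 = (num_experts, 2 * intermediate_size, hidden_size // 2)      # packed fp4
w13_scales = (num_experts, 2 * intermediate_size, hidden_size // 16)  # fp8 block scales
w2_fp4 = (num_experts, hidden_size, intermediate_size // 2)
w2_scales = (num_experts, hidden_size, intermediate_size // 16)

assert w13_fp4 == (8, 4096, 2048) and w13_scales == (8, 4096, 256)
assert w2_fp4 == (8, 4096, 1024) and w2_scales == (8, 4096, 128)
```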
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
# GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer_cutlass:
if self.allow_flashinfer:
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2)
@@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
requires_grad=False)
# Common processing for w13_weight_scale_2
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1]):
logger.warning_once(
@@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
requires_grad=False)
# Common processing for input scales and alphas
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False)
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
requires_grad=False)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
# GEMM 2
# GEMM 2 processing
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False)
@@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
# TensorRT-LLM specific processing
if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# Prepare static weights for TRT-LLM kernel
(gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
) = self.prepare_static_weight_layouts_for_trtllm_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
layer.gemm1_weights_fp4_shuffled = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False)
layer.gemm2_scales_fp4_shuffled = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(
torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else:
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w13_blockscale_swizzled = swizzle_blockscale(
layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
requires_grad=False)
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data,
requires_grad=False)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
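
To summarize the two branches above, a layer ends up holding different tensors depending on the backend. The listing below is a reading of those branches (assuming FlashInfer is available), not code from the diff:

```python
# Tensors expected on the layer after process_weights_after_loading.
TRTLLM_LAYER_ATTRS = [
    "gemm1_weights_fp4_shuffled", "gemm1_scales_fp4_shuffled",
    "gemm2_weights_fp4_shuffled", "gemm2_scales_fp4_shuffled",
    "g1_scale_c", "g1_alphas", "g2_alphas",
    "w13_input_scale_quant", "w2_input_scale_quant",
]  # the original w13/w2 weights and block scales are deleted

CUTLASS_LAYER_ATTRS = [
    "w13_weight", "w13_blockscale_swizzled",
    "w2_weight", "w2_blockscale_swizzled",
    "g1_alphas", "g2_alphas",
    "w13_input_scale_quant", "w2_input_scale_quant",
]  # w13 keeps the [w3, w1] reordering applied earlier for CUTLASS
```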
@@ -1095,6 +1258,60 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4,
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
use_llama4_routing = \
custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits
if use_llama4_routing else router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
layer.local_num_experts),
routing_method_type=routing_method_type,
do_finalize=True,
)[0]
return out
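
The expert-parallel arguments passed above follow the usual sharding arithmetic; a small illustration with made-up sizes:

```python
# Illustrative EP sharding arithmetic behind local_expert_offset and
# local_num_experts (sizes are made up, not from the diff).
global_num_experts, ep_size, ep_rank = 256, 8, 3
local_num_experts = global_num_experts // ep_size   # 32 experts per rank
local_expert_offset = ep_rank * local_num_experts   # rank 3 owns experts 96..127
assert (local_num_experts, local_expert_offset) == (32, 96)
```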
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -1149,6 +1366,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
@@ -1160,4 +1379,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
return out


@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(
def select_nvfp4_gemm_impl(
allow_flashinfer_cutlass: bool,
allow_flashinfer: bool,
moe, # FusedMoEConfig
logger):
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if allow_flashinfer_cutlass:
logger.debug_once("Using FlashInferExperts")
if allow_flashinfer:
flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_backend != "throughput":
raise ValueError(
f"Only throughput backend is supported for FlashInferExperts, "
f"but got {flashinfer_backend}.")
logger.debug_once(
"Initializing FlashInferExperts with throughput backend.")
return FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,


@@ -21,7 +21,7 @@ class NvFp4Support:
"""Result container for NV-FP4 capability probing."""
cutlass_supported: bool
allow_flashinfer_cutlass: bool
allow_flashinfer: bool
use_marlin: bool
@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
return NvFp4Support(
cutlass_supported=cutlass_supported,
allow_flashinfer_cutlass=allow_flashinfer,
allow_flashinfer=allow_flashinfer,
use_marlin=use_marlin,
)


@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave")
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
"flashinfer", "trtllm_fp4_block_scale_moe")
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "nvfp4_block_scale_interleave"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
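
The availability probe now also requires the TRT-LLM block-scale MoE entry point. A stand-in sketch of the probe pattern (the real check lives in vllm.utils.flashinfer):

```python
import importlib

# Stand-in availability probe: every (module, attribute) pair must resolve.
def _has_attrs(required: list[tuple[str, str]]) -> bool:
    for module_name, attr_name in required:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            return False
        if not hasattr(module, attr_name):
            return False
    return True

# e.g. the CUTLASS fused-MoE check now includes the TRT-LLM kernel:
# _has_attrs([
#     ("flashinfer.fused_moe", "cutlass_fused_moe"),
#     ("flashinfer", "fp4_quantize"),
#     ("flashinfer", "nvfp4_block_scale_interleave"),
#     ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
# ])
```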
@@ -188,6 +191,7 @@ __all__ = [
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune",
"has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe",