[Feature] Add Flashinfer MoE Support for Compressed Tensor NVFP4 (#21639)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye authored on 2025-07-31 18:26:11 -04:00; committed by GitHub
parent 6e672daf62
commit c3e0e9337e
4 changed files with 287 additions and 129 deletions
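For context, the new FlashInfer CUTLASS NVFP4 fused-MoE path is gated by the VLLM_USE_FLASHINFER_MOE_FP4 environment variable and needs a CUDA device with compute capability 10.0 (Blackwell). A minimal opt-in sketch from Python follows; the checkpoint name is a placeholder for any NVFP4 compressed-tensors MoE model, not something this commit ships.

import os

# Set the flag before vLLM reads its environment variables.
os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"

from vllm import LLM  # assumes vLLM is installed with FlashInfer available

# Placeholder checkpoint id; substitute a real NVFP4 compressed-tensors MoE model.
llm = LLM(model="org/nvfp4-moe-model")
print(llm.generate("Hello, FlashInfer MoE!"))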

File: vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -17,9 +17,14 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
@@ -28,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, swizzle_blockscale)
swizzle_blockscale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -96,8 +101,14 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self):
self.use_marlin = not cutlass_fp4_supported()
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.fused_experts = None # type: ignore[assignment]
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -200,6 +211,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
requires_grad=False)
# Reorder GEMM1 weights and block scales for the FlashInfer CUTLASS kernel.
if self.allow_flashinfer_cutlass:
w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
layer.w13_weight_scale.data,
dim=-2)
layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
if not torch.allclose(layer.w13_weight_global_scale[:, 0],
layer.w13_weight_global_scale[:, 1]):
logger.warning_once(
@@ -246,6 +265,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config):
if not self.allow_flashinfer_cutlass:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
def select_gemm_impl(self, prepare_finalize, moe):
"""Return the appropriate GEMM experts implementation."""
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
logger)
def apply(
self,
layer: torch.nn.Module,
@@ -303,10 +337,23 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_num_experts=global_num_experts,
expert_map=expert_map)
# FlashInfer fused experts path
if self.fused_experts is not None:
return flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
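To summarize the dispatch the updated CompressedTensorsW4A4MoeMethod.apply() performs, here is a minimal self-contained sketch (a dummy function, not a vLLM API); it assumes the Marlin branch, of which this excerpt only shows the tail, is checked first, followed by the FlashInfer modular kernel when it was built, and the existing CUTLASS path otherwise.

from typing import Optional

def pick_w4a4_moe_backend(use_marlin: bool,
                          fused_experts: Optional[object]) -> str:
    """Mirror of the backend order: Marlin, then FlashInfer, then CUTLASS."""
    if use_marlin:
        return "marlin_fp4"
    if fused_experts is not None:  # built by maybe_swap_experts_impl()
        return "flashinfer_cutlass_fp4"
    return "cutlass_moe_fp4"       # no expert_map / EP support on this path

assert pick_w4a4_moe_backend(False, object()) == "flashinfer_cutlass_fp4"
assert pick_w4a4_moe_backend(True, None) == "marlin_fp4"
assert pick_w4a4_moe_backend(False, None) == "cutlass_moe_fp4"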

File: vllm/model_executor/layers/quantization/modelopt.py

@@ -10,11 +10,8 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.distributed import get_ep_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer_moe
@@ -869,28 +868,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
self.allow_flashinfer_cutlass = False
if envs.VLLM_USE_FLASHINFER_MOE_FP4:
if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
and current_platform.is_device_capability(100):
logger.info_once(
"Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
self.allow_flashinfer_cutlass = True
else:
logger.warning_once(
"Flashinfer CUTLASS Fused MoE not supported "
"or found on the current platform.")
if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
self.use_marlin = _nvfp4.use_marlin
self.fused_experts = None # type: ignore
@@ -900,29 +883,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
):
if not self.allow_flashinfer_cutlass:
return
logger.debug_once("FlashInferExperts")
# default to TP/EP case only
experts_kwargs: dict[str, Any] = {
"use_nvfp4_w4a4": True,
"use_dp": moe_parallel_config.dp_size > 1,
"ep_rank": moe_parallel_config.ep_rank,
"ep_size": moe_parallel_config.ep_size,
"tp_rank": moe_parallel_config.tp_rank,
"tp_size": moe_parallel_config.tp_size,
}
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
experts = FlashInferExperts(**experts_kwargs)
self.fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=torch.uint8,
#meaning 2x e2m1 packed in one, kernel requirement
),
experts,
)
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
# This method updates self.fused_experts.
# select_gemm_impl is only called when prepare_finalize is not None.
@@ -931,32 +893,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
assert moe is not None
assert prepare_finalize is not None
experts = None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if self.allow_flashinfer_cutlass:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
logger.debug_once("Using FlashInferExperts")
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
)
else:
assert moe.dp_size > 1
logger.debug_once("Using CutlassExpertsFp4")
# Currently CutlassExpertsFp4 doesn't support DP
raise ValueError("CutlassExpertsFp4 doesn't support DP. "
"Use flashinfer CUTLASS FusedMoE backend instead "
"(set VLLM_USE_FLASHINFER_MOE_FP4=1)")
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return experts
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
logger)
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -1062,18 +1004,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer_cutlass:
dim = -2
size = gemm1_weight.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2
# Reorder weight
w1, w3 = gemm1_weight.split(half, dim=dim)
gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
# Reorder scale
s1, s3 = gemm1_weight_scale.split(half, dim=dim)
gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2)
layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
@@ -1217,49 +1149,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
# TP or DP case
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
is_valid_flashinfer_cutlass_fused_moe)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
extra_expert_args = {
'g1_alphas': layer.g1_alphas,
'g2_alphas': layer.g2_alphas,
'out_dtype': x.dtype,
# Avoid confusion with a1_scale and a2_scale
# where are batch size related.
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
}
extra_prepare_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
'a1_gscale': a1_gscale,
}
extra_finalize_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
}
out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
return out
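Both quantization methods now delegate to a FusedMoEModularKernel built by the shared helper. A toy illustration of that composition pattern (stand-in classes, not vLLM's) shows how a prepare/finalize stage wraps an experts stage behind one callable:

class ToyPrepareAndFinalize:
    def prepare(self, x):
        return x * 2   # e.g. quantize inputs / scatter tokens for DP

    def finalize(self, y):
        return y + 1   # e.g. gather outputs / unpermute tokens

class ToyExperts:
    def __call__(self, x):
        return x * 10  # e.g. grouped GEMMs over the routed tokens

class ToyModularKernel:
    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, x):
        return self.prepare_finalize.finalize(
            self.experts(self.prepare_finalize.prepare(x)))

kernel = ToyModularKernel(ToyPrepareAndFinalize(), ToyExperts())
assert kernel(3) == 61  # finalize(experts(prepare(3))) = 3*2*10 + 1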

File: vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py (new file)

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from __future__ import annotations
from typing import Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_kernel",
"flashinfer_fp4_cutlass_moe_forward",
]
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
and current_platform.is_device_capability(100))
def reorder_w1w3_to_w3w1(weight: torch.Tensor,
scale: torch.Tensor,
dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]:
"""Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`"""
size = weight.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2
w1, w3 = weight.split(half, dim=dim)
s1, s3 = scale.split(half, dim=dim)
return (torch.cat([w3, w1],
dim=dim).contiguous(), torch.cat([s3, s1],
dim=dim).contiguous())
def build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
"""Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe_parallel_config.dp_size > 1,
ep_rank=moe_parallel_config.ep_rank,
ep_size=moe_parallel_config.ep_size,
tp_rank=moe_parallel_config.tp_rank,
tp_size=moe_parallel_config.tp_size,
)
logger.debug_once("FlashInferExperts (util)")
return mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
experts,
)
def flashinfer_fp4_cutlass_moe_forward(
fused_experts: mk.FusedMoEModularKernel,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
"""Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight,
layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
extra_expert_args = {
"g1_alphas": layer.g1_alphas,
"g2_alphas": layer.g2_alphas,
# Avoid confusion with a1_scale and a2_scale,
# which are batch-size related.
"a1_gscale": a1_gscale,
"a2_gscale": a2_gscale,
"out_dtype": x.dtype,
}
extra_prepare_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
"a1_gscale": a1_gscale,
}
extra_finalize_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
}
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,  # TODO(shuw): fix later; output is currently in high precision
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
def select_nvfp4_gemm_impl(
allow_flashinfer_cutlass: bool,
moe, # FusedMoEConfig
logger):
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
# lazy import
from vllm.distributed import get_ep_group
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if allow_flashinfer_cutlass:
logger.debug_once("Using FlashInferExperts")
return FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
)
# Native CUTLASS experts currently don't support DP; the pure-TP case won't reach this.
raise ValueError(
"CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
"Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)")
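The reorder_w1w3_to_w3w1 helper above swaps the stacked [w1, w3] halves of the GEMM1 weights (and their block scales) into the [w3, w1] order the FlashInfer CUTLASS kernel expects. Below is a self-contained check of that behavior, re-implementing the helper with plain torch so it runs without FlashInfer installed; the toy shapes are illustrative only.

import torch

def reorder_w1w3_to_w3w1(weight, scale, dim=-2):
    # Same logic as the helper above, copied here for a standalone demo.
    half = weight.size(dim) // 2
    w1, w3 = weight.split(half, dim=dim)
    s1, s3 = scale.split(half, dim=dim)
    return (torch.cat([w3, w1], dim=dim).contiguous(),
            torch.cat([s3, s1], dim=dim).contiguous())

# Toy shapes: (num_experts, 2 * intermediate_size, packed_hidden_size).
w = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
s = torch.arange(2 * 4 * 1).reshape(2, 4, 1)
w_new, s_new = reorder_w1w3_to_w3w1(w, s)
assert torch.equal(w_new[:, :2], w[:, 2:])  # w3 half now comes first
assert torch.equal(w_new[:, 2:], w[:, :2])  # followed by the w1 half
assert torch.equal(s_new[:, :2], s[:, 2:])  # block scales move with the weights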

File: vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py (new file)

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutlass_moe_available)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
_logger = init_logger(__name__)
@dataclass(frozen=True)
class NvFp4Support:
"""Result container for NV-FP4 capability probing."""
cutlass_supported: bool
allow_flashinfer_cutlass: bool
use_marlin: bool
def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
"""Detect platform support for NV-FP4 fused-MoE path"""
cutlass_supported = cutlass_fp4_supported()
allow_flashinfer = (cutlass_supported
and is_flashinfer_fp4_cutlass_moe_available())
if allow_flashinfer:
_logger.info_once("Using FlashInfer kernels for %s.", class_name
or "NVFP4 path")
else:
if envs.VLLM_USE_FLASHINFER_MOE_FP4:
_logger.warning_once(
"FlashInfer kernels unavailable for %s on current platform.",
class_name or "NVFP4 path",
)
use_marlin = False
if not cutlass_supported:
if is_fp4_marlin_supported():
use_marlin = True
_logger.info_once("Falling back to Marlin FP4 MoE kernel.")
else:
raise ValueError(
"Current platform does not support NVFP4 quantization. "
"Please use Blackwell GPUs or enable FlashInfer.")
return NvFp4Support(
cutlass_supported=cutlass_supported,
allow_flashinfer_cutlass=allow_flashinfer,
use_marlin=use_marlin,
)
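The detection logic above reduces to a small decision table. A pure-Python mirror (no GPU or vLLM install needed; the function name and boolean inputs are illustrative, with flashinfer_available standing in for is_flashinfer_fp4_cutlass_moe_available()) makes the three outcomes explicit:

def nvfp4_decision(cutlass_fp4: bool, flashinfer_available: bool,
                   marlin_fp4: bool) -> tuple[bool, bool, bool]:
    """Return (cutlass_supported, allow_flashinfer_cutlass, use_marlin)."""
    allow_flashinfer = cutlass_fp4 and flashinfer_available
    if not cutlass_fp4 and not marlin_fp4:
        raise ValueError("Current platform does not support NVFP4 quantization.")
    use_marlin = not cutlass_fp4
    return cutlass_fp4, allow_flashinfer, use_marlin

assert nvfp4_decision(True, True, False) == (True, True, False)    # Blackwell + flag set
assert nvfp4_decision(True, False, False) == (True, False, False)  # native CUTLASS only
assert nvfp4_decision(False, False, True) == (False, False, True)  # Marlin fallback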