[Feature] Add Flashinfer MoE Support for Compressed Tensor NVFP4 (#21639)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 6e672daf62
Commit: c3e0e9337e
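
This commit wires vLLM's compressed-tensors NVFP4 MoE path into the FlashInfer CUTLASS fused-MoE kernels. A minimal sketch of opting in, assuming a Blackwell (SM100) GPU and a compressed-tensors NVFP4 MoE checkpoint; the model id below is a placeholder, not something named in this commit:

    # Sketch: enable the FlashInfer CUTLASS NVFP4 fused-MoE backend.
    # VLLM_USE_FLASHINFER_MOE_FP4 is the env flag checked in this diff;
    # the model id is hypothetical.
    import os
    os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"

    from vllm import LLM

    llm = LLM(model="org/nvfp4-moe-model")  # placeholder checkpoint id
    print(llm.generate(["Hello"])[0].outputs[0].text)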
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -17,9 +17,14 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
     FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa
+    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    build_flashinfer_fp4_cutlass_moe_kernel,
+    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -28,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    cutlass_fp4_supported, swizzle_blockscale)
+    swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -96,8 +101,14 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
 
     def __init__(self):
-        self.use_marlin = not cutlass_fp4_supported()
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
+            detect_nvfp4_moe_support)
+        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
+        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
+        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
+        self.fused_experts = None  # type: ignore[assignment]
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -200,6 +211,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
                                              requires_grad=False)
 
+        # Reorder GEMM1 weights and block scales for the FlashInfer CUTLASS kernel.
+        if self.allow_flashinfer_cutlass:
+            w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
+                                        layer.w13_weight_scale.data,
+                                        dim=-2)
+            layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
+
         if not torch.allclose(layer.w13_weight_global_scale[:, 0],
                               layer.w13_weight_global_scale[:, 1]):
             logger.warning_once(
@@ -246,6 +265,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         layer.w2_input_scale_quant = torch.nn.Parameter(
             (layer.w2_input_global_scale), requires_grad=False)
 
+    def maybe_swap_experts_impl(self, moe_parallel_config):
+        if not self.allow_flashinfer_cutlass:
+            return
+        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
+            moe_parallel_config)
+
+    def select_gemm_impl(self, prepare_finalize, moe):
+        """Return the appropriate GEMM experts implementation."""
+        assert moe is not None and prepare_finalize is not None
+        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
+            select_nvfp4_gemm_impl)
+
+        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
+                                      logger)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -303,10 +337,23 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
 
+        # FlashInfer fused experts path
+        if self.fused_experts is not None:
+            return flashinfer_fp4_cutlass_moe_forward(
+                self.fused_experts,
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
 
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
 
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             cutlass_moe_fp4)
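
With this hunk, apply() tries the backends in a fixed order. A condensed sketch of the resulting control flow, using only names visible in this diff (the Marlin branch sits in elided lines above and is paraphrased here):

    # Backend selection in CompressedTensorsW4A4MoeMethod.apply() (sketch):
    if self.use_marlin:
        ...                                   # Marlin FP4 path (elided above)
    elif self.fused_experts is not None:      # FlashInfer modular kernel built
        return flashinfer_fp4_cutlass_moe_forward(
            self.fused_experts, layer, x, topk_weights, topk_ids,
            activation=activation, global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)
    else:
        assert expert_map is None             # CUTLASS path: no expert parallelism
        return cutlass_moe_fp4(...)           # plain CUTLASS FP4 experts (elided)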
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -10,11 +10,8 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    build_flashinfer_fp4_cutlass_moe_kernel,
+    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
     swap_w13_to_w31)
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
-from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.flashinfer import has_flashinfer_moe
 
@@ -869,28 +868,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
 
     def __init__(self, quant_config: ModelOptNvFp4Config):
         self.quant_config = quant_config
-        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
-        self.use_marlin = False
-        self.allow_flashinfer_cutlass = False
-
-        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
-            if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
-                    and current_platform.is_device_capability(100):
-                logger.info_once(
-                    "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
-                self.allow_flashinfer_cutlass = True
-            else:
-                logger.warning_once(
-                    "Flashinfer CUTLASS Fused MoE not supported "
-                    "or found on the current platform.")
-
-        if not self.cutlass_nvfp4_supported:
-            if is_fp4_marlin_supported():
-                self.use_marlin = True
-            else:
-                raise ValueError("Current platform does not support NVFP4"
-                                 " quantization. Please use Blackwell and"
-                                 " above.")
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
+            detect_nvfp4_moe_support)
+        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
+        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
+        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.use_marlin = _nvfp4.use_marlin
 
         self.fused_experts = None  # type: ignore
 
@@ -900,29 +883,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     ):
         if not self.allow_flashinfer_cutlass:
             return
-
-        logger.debug_once("FlashInferExperts")
-        # default to TP/EP case only
-
-        experts_kwargs: dict[str, Any] = {
-            "use_nvfp4_w4a4": True,
-            "use_dp": moe_parallel_config.dp_size > 1,
-            "ep_rank": moe_parallel_config.ep_rank,
-            "ep_size": moe_parallel_config.ep_size,
-            "tp_rank": moe_parallel_config.tp_rank,
-            "tp_size": moe_parallel_config.tp_size,
-        }
-
-        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-            FlashInferExperts)
-        experts = FlashInferExperts(**experts_kwargs)
-        self.fused_experts = mk.FusedMoEModularKernel(
-            FlashInferCutlassMoEPrepareAndFinalize(
-                quant_dtype=torch.uint8,
-                # meaning 2x e2m1 packed in one, kernel requirement
-            ),
-            experts,
-        )
+        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
+            moe_parallel_config)
 
     # This method updates self.fused_experts;
     # select_gemm_impl is only called when prepare_finalize is not None.
@@ -931,32 +893,12 @@
     def select_gemm_impl(self, prepare_finalize,
                          moe) -> mk.FusedMoEPermuteExpertsUnpermute:
 
-        assert moe is not None
-        assert prepare_finalize is not None
-        experts = None
-        all2all_manager = get_ep_group().device_communicator.all2all_manager
-        assert all2all_manager is not None
-        if self.allow_flashinfer_cutlass:
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                FlashInferExperts)
-            logger.debug_once("Using FlashInferExperts")
-            experts = FlashInferExperts(
-                use_nvfp4_w4a4=True,
-                use_dp=moe.moe_parallel_config.dp_size > 1,
-                ep_rank=moe.moe_parallel_config.ep_rank,
-                ep_size=moe.moe_parallel_config.ep_size,
-                tp_rank=moe.moe_parallel_config.tp_rank,
-                tp_size=moe.moe_parallel_config.tp_size,
-            )
-        else:
-            assert moe.dp_size > 1
-            logger.debug_once("Using CutlassExpertsFp4")
-            # Currently CutlassExpertsFp4 doesn't support DP
-            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
-                             "Use flashinfer CUTLASS FusedMoE backend instead "
-                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")
-
-        return experts
+        assert moe is not None and prepare_finalize is not None
+        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
+            select_nvfp4_gemm_impl)
+
+        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
+                                      logger)
 
     def uses_weight_scale_2_pattern(self) -> bool:
         """
@@ -1062,18 +1004,8 @@
         gemm1_weight_scale = layer.w13_weight_scale.data
 
         if self.allow_flashinfer_cutlass:
-            dim = -2
-            size = gemm1_weight.size(dim)
-            assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
-            half = size // 2
-
-            # Reorder weight
-            w1, w3 = gemm1_weight.split(half, dim=dim)
-            gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
-
-            # Reorder scale
-            s1, s3 = gemm1_weight_scale.split(half, dim=dim)
-            gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
+            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
+                gemm1_weight, gemm1_weight_scale, dim=-2)
 
         layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(gemm1_weight_scale,
@@ -1217,49 +1149,15 @@
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
-            # TP or DP case
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                is_valid_flashinfer_cutlass_fused_moe)
-            assert is_valid_flashinfer_cutlass_fused_moe(
-                x, layer.w13_weight, layer.w2_weight), (
-                    "Flashinfer CUTLASS Fused MoE not applicable!")
-
-            a1_gscale = layer.w13_input_scale_quant
-            a2_gscale = layer.w2_input_scale_quant
-            extra_expert_args = {
-                'g1_alphas': layer.g1_alphas,
-                'g2_alphas': layer.g2_alphas,
-                'out_dtype': x.dtype,
-                # Avoid confusion with a1_scale and a2_scale,
-                # which are batch-size related.
-                'a1_gscale': a1_gscale,
-                'a2_gscale': a2_gscale,
-            }
-            extra_prepare_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-                'a1_gscale': a1_gscale,
-            }
-            extra_finalize_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-            }
-
-            out = self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=False,  # TODO(shuw): fix later, now output is high prec
+            out = flashinfer_fp4_cutlass_moe_forward(
+                self.fused_experts,
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args,
-                extra_prepare_args=extra_prepare_args,
-                extra_finalize_args=extra_finalize_args,
             )
         return out
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Utility helpers for the NVFP4 + FlashInfer fused-MoE path."""
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+__all__ = [
+    "is_flashinfer_fp4_cutlass_moe_available",
+    "reorder_w1w3_to_w3w1",
+    "build_flashinfer_fp4_cutlass_moe_kernel",
+    "flashinfer_fp4_cutlass_moe_forward",
+]
+
+
+def is_flashinfer_fp4_cutlass_moe_available() -> bool:
+    """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
+    return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
+            and current_platform.is_device_capability(100))
+
+
+def reorder_w1w3_to_w3w1(weight: torch.Tensor,
+                         scale: torch.Tensor,
+                         dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]:
+    """Re-order the concatenated ``[w1, w3]`` tensors to ``[w3, w1]``."""
+    size = weight.size(dim)
+    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+    half = size // 2
+
+    w1, w3 = weight.split(half, dim=dim)
+    s1, s3 = scale.split(half, dim=dim)
+
+    return (torch.cat([w3, w1],
+                      dim=dim).contiguous(), torch.cat([s3, s1],
+                                                       dim=dim).contiguous())
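
A quick sanity check of reorder_w1w3_to_w3w1 on toy tensors; the shapes are hypothetical, chosen only to show the two halves along dim=-2 trading places:

    import torch

    # 2 "experts", 4 rows in the merged dim ([w1; w3] stacked), 3 columns.
    w = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
    s = torch.arange(8, dtype=torch.float32).reshape(2, 4, 1)

    w_new, s_new = reorder_w1w3_to_w3w1(w, s, dim=-2)
    assert torch.equal(w_new[:, :2], w[:, 2:])  # w3 half now leads
    assert torch.equal(s_new[:, 2:], s[:, :2])  # w1 scales now trail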
+
+
+def build_flashinfer_fp4_cutlass_moe_kernel(
+        moe_parallel_config: FusedMoEParallelConfig) -> mk.FusedMoEModularKernel:
+    """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel."""
+    experts = FlashInferExperts(
+        use_nvfp4_w4a4=True,
+        use_dp=moe_parallel_config.dp_size > 1,
+        ep_rank=moe_parallel_config.ep_rank,
+        ep_size=moe_parallel_config.ep_size,
+        tp_rank=moe_parallel_config.tp_rank,
+        tp_size=moe_parallel_config.tp_size,
+    )
+    logger.debug_once("FlashInferExperts (util)")
+    return mk.FusedMoEModularKernel(
+        FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
+        experts,
+    )
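
The quant_dtype=torch.uint8 argument reflects the packing requirement noted in the removed modelopt.py comment ("2x e2m1 packed in one"): each uint8 element carries two FP4 (e2m1) values, so a quantized activation is half as wide as its input. A toy illustration of that bookkeeping, not the kernel's actual quantizer:

    import torch

    M, K = 4, 128
    x = torch.randn(M, K)
    packed = torch.empty(M, K // 2, dtype=torch.uint8)  # two e2m1 nibbles per byte
    assert packed.numel() * 2 == x.numel()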
+
+
+def flashinfer_fp4_cutlass_moe_forward(
+        fused_experts: mk.FusedMoEModularKernel,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+) -> torch.Tensor:
+    """Common forward wrapper for FlashInfer NV-FP4 fused-MoE."""
+
+    assert is_valid_flashinfer_cutlass_fused_moe(
+        x, layer.w13_weight,
+        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
+
+    a1_gscale = layer.w13_input_scale_quant
+    a2_gscale = layer.w2_input_scale_quant
+
+    extra_expert_args = {
+        "g1_alphas": layer.g1_alphas,
+        "g2_alphas": layer.g2_alphas,
+        # Avoid confusion with a1_scale and a2_scale,
+        # which are batch-size related.
+        "a1_gscale": a1_gscale,
+        "a2_gscale": a2_gscale,
+        "out_dtype": x.dtype,
+    }
+    extra_prepare_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+        "a1_gscale": a1_gscale,
+    }
+    extra_finalize_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+    }
+
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,  # TODO(shuw): fix later, now output is high prec
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=layer.w13_blockscale_swizzled,
+        w2_scale=layer.w2_blockscale_swizzled,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        extra_expert_args=extra_expert_args,
+        extra_prepare_args=extra_prepare_args,
+        extra_finalize_args=extra_finalize_args,
+    )
+
+
+def select_nvfp4_gemm_impl(
+        allow_flashinfer_cutlass: bool,
+        moe,  # FusedMoEConfig
+        logger):
+    """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers."""
+
+    # lazy import
+    from vllm.distributed import get_ep_group
+
+    all2all_manager = get_ep_group().device_communicator.all2all_manager
+    assert all2all_manager is not None
+
+    if allow_flashinfer_cutlass:
+        logger.debug_once("Using FlashInferExperts")
+        return FlashInferExperts(
+            use_nvfp4_w4a4=True,
+            use_dp=moe.moe_parallel_config.dp_size > 1,
+            ep_rank=moe.moe_parallel_config.ep_rank,
+            ep_size=moe.moe_parallel_config.ep_size,
+            tp_rank=moe.moe_parallel_config.tp_rank,
+            tp_size=moe.moe_parallel_config.tp_size,
+        )
+
+    # Native CUTLASS experts currently don't support DP; the TP case
+    # won't reach this helper.
+    raise ValueError(
+        "CutlassExpertsFp4 doesn't support DP. Use the FlashInfer CUTLASS "
+        "fused-MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)")
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import vllm.envs as envs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    is_flashinfer_fp4_cutlass_moe_available)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
+
+__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
+
+_logger = init_logger(__name__)
+
+
+@dataclass(frozen=True)
+class NvFp4Support:
+    """Result container for NV-FP4 capability probing."""
+
+    cutlass_supported: bool
+    allow_flashinfer_cutlass: bool
+    use_marlin: bool
+
+
+def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
+    """Detect platform support for the NV-FP4 fused-MoE path."""
+    cutlass_supported = cutlass_fp4_supported()
+
+    allow_flashinfer = (cutlass_supported
+                        and is_flashinfer_fp4_cutlass_moe_available())
+
+    if allow_flashinfer:
+        _logger.info_once("Using FlashInfer kernels for %s.", class_name
+                          or "NVFP4 path")
+    else:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
+            _logger.warning_once(
+                "FlashInfer kernels unavailable for %s on current platform.",
+                class_name or "NVFP4 path",
+            )
+
+    use_marlin = False
+    if not cutlass_supported:
+        if is_fp4_marlin_supported():
+            use_marlin = True
+            _logger.info_once("Falling back to Marlin FP4 MoE kernel.")
+        else:
+            raise ValueError(
+                "Current platform does not support NVFP4 quantization. "
+                "Please use Blackwell GPUs or enable FlashInfer.")
+
+    return NvFp4Support(
+        cutlass_supported=cutlass_supported,
+        allow_flashinfer_cutlass=allow_flashinfer,
+        use_marlin=use_marlin,
+    )
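
A minimal sketch of how a quantization method consumes this helper, mirroring the two __init__ rewrites earlier in this diff:

    from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
        detect_nvfp4_moe_support)

    support = detect_nvfp4_moe_support("CompressedTensorsW4A4MoeMethod")
    # Raises ValueError when neither CUTLASS FP4 nor Marlin is available.
    if support.allow_flashinfer_cutlass:
        ...  # build the FlashInfer modular kernel
    elif support.use_marlin:
        ...  # fall back to the Marlin FP4 path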