diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 17b41e8a1c23..09d8890888fa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -17,9 +17,14 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + build_flashinfer_fp4_cutlass_moe_kernel, + flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -28,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, swizzle_blockscale) + swizzle_blockscale) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -96,8 +101,14 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): def __init__(self): - self.use_marlin = not cutlass_fp4_supported() + from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 + detect_nvfp4_moe_support) + _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) + self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported + self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass + self.use_marlin = _nvfp4.use_marlin self.group_size = 16 + self.fused_experts = None # type: ignore[assignment] def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -200,6 +211,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data, requires_grad=False) + # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. 
+ if self.allow_flashinfer_cutlass: + w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data, + layer.w13_weight_scale.data, + dim=-2) + layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) + if not torch.allclose(layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]): logger.warning_once( @@ -246,6 +265,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): layer.w2_input_scale_quant = torch.nn.Parameter( (layer.w2_input_global_scale), requires_grad=False) + def maybe_swap_experts_impl(self, moe_parallel_config): + if not self.allow_flashinfer_cutlass: + return + self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( + moe_parallel_config) + + def select_gemm_impl(self, prepare_finalize, moe): + """Return the appropriate GEMM experts implementation.""" + assert moe is not None and prepare_finalize is not None + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 + select_nvfp4_gemm_impl) + + return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe, + logger) + def apply( self, layer: torch.nn.Module, @@ -303,10 +337,23 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): global_num_experts=global_num_experts, expert_map=expert_map) + # FlashInfer fused experts path + if self.fused_experts is not None: + return flashinfer_fp4_cutlass_moe_forward( + self.fused_experts, + layer, + x, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "CompressedTensorsW4A4MoeMethod.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b8ffcf90c022..0334a2824512 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -10,11 +10,8 @@ from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.distributed import get_ep_group from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -23,6 +20,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + build_flashinfer_fp4_cutlass_moe_kernel, + flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) @@ -35,7 +35,6 @@ from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale) from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import has_flashinfer_moe @@ -869,28 +868,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config - self.cutlass_nvfp4_supported = cutlass_fp4_supported() - self.use_marlin = False - self.allow_flashinfer_cutlass = False - - if envs.VLLM_USE_FLASHINFER_MOE_FP4: - if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ - and current_platform.is_device_capability(100): - logger.info_once( - "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") - self.allow_flashinfer_cutlass = True - else: - logger.warning_once( - "Flashinfer CUTLASS Fused MoE not supported " - "or found on the current platform.") - - if not self.cutlass_nvfp4_supported: - if is_fp4_marlin_supported(): - self.use_marlin = True - else: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above.") + from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 + detect_nvfp4_moe_support) + _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) + self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported + self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass + self.use_marlin = _nvfp4.use_marlin self.fused_experts = None # type: ignore @@ -900,29 +883,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ): if not self.allow_flashinfer_cutlass: return - - logger.debug_once("FlashInferExperts") - # default to TP/EP case only - - experts_kwargs: dict[str, Any] = { - "use_nvfp4_w4a4": True, - "use_dp": moe_parallel_config.dp_size > 1, - "ep_rank": moe_parallel_config.ep_rank, - "ep_size": moe_parallel_config.ep_size, - "tp_rank": moe_parallel_config.tp_rank, - "tp_size": moe_parallel_config.tp_size, - } - - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - FlashInferExperts) - experts = FlashInferExperts(**experts_kwargs) - self.fused_experts = mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=torch.uint8, - #meaning 2x e2m1 packed in one, kernel requirement - ), - experts, - ) + self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( + moe_parallel_config) # This method update self.fused_experts # only prepare_finalize is not None call select_gemm_impl @@ -931,32 +893,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def select_gemm_impl(self, prepare_finalize, moe) -> mk.FusedMoEPermuteExpertsUnpermute: - assert moe is not None - assert prepare_finalize is not None - experts = None - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - if self.allow_flashinfer_cutlass: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - FlashInferExperts) - logger.debug_once("Using FlashInferExperts") - experts = FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe.moe_parallel_config.dp_size > 1, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - ) - else: - assert moe.dp_size > 1 - logger.debug_once("Using CutlassExpertsFp4") - # Currently 
CutlassExpertsFp4 doesn't support DP - raise ValueError("CutlassExpertsFp4 doesn't support DP. " - "Use flashinfer CUTLASS FusedMoE backend instead " - "(set VLLM_USE_FLASHINFER_MOE_FP4=1)") + assert moe is not None and prepare_finalize is not None + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 + select_nvfp4_gemm_impl) - return experts + return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe, + logger) def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1062,18 +1004,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): gemm1_weight_scale = layer.w13_weight_scale.data if self.allow_flashinfer_cutlass: - dim = -2 - size = gemm1_weight.size(dim) - assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" - half = size // 2 - - # Reorder weight - w1, w3 = gemm1_weight.split(half, dim=dim) - gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous() - - # Reorder scale - s1, s3 = gemm1_weight_scale.split(half, dim=dim) - gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous() + gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( + gemm1_weight, gemm1_weight_scale, dim=-2) layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) layer.w13_weight_scale = Parameter(gemm1_weight_scale, @@ -1217,49 +1149,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) else: - # TP or DP case - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - is_valid_flashinfer_cutlass_fused_moe) - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") - - a1_gscale = layer.w13_input_scale_quant - a2_gscale = layer.w2_input_scale_quant - extra_expert_args = { - 'g1_alphas': layer.g1_alphas, - 'g2_alphas': layer.g2_alphas, - 'out_dtype': x.dtype, - # Avoid confusion with a1_scale and a2_scale - # where are batch size related. 
-                'a1_gscale': a1_gscale,
-                'a2_gscale': a2_gscale,
-            }
-            extra_prepare_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-                'a1_gscale': a1_gscale,
-            }
-            extra_finalize_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-            }
-
-            out = self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=False,  # TODO(shuw): fix later, now output is high prec
+            out = flashinfer_fp4_cutlass_moe_forward(
+                self.fused_experts,
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args,
-                extra_prepare_args=extra_prepare_args,
-                extra_finalize_args=extra_finalize_args,
             )
         return out
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
new file mode 100644
index 000000000000..4c617e226041
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+__all__ = [
+    "is_flashinfer_fp4_cutlass_moe_available",
+    "reorder_w1w3_to_w3w1",
+    "build_flashinfer_fp4_cutlass_moe_kernel",
+    "flashinfer_fp4_cutlass_moe_forward",
+]
+
+
+def is_flashinfer_fp4_cutlass_moe_available() -> bool:
+    """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
+    return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
+            and current_platform.is_device_capability(100))
+
+
+def reorder_w1w3_to_w3w1(weight: torch.Tensor,
+                         scale: torch.Tensor,
+                         dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]:
+    """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`"""
+    size = weight.size(dim)
+    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+    half = size // 2
+
+    w1, w3 = weight.split(half, dim=dim)
+    s1, s3 = scale.split(half, dim=dim)
+
+    return (torch.cat([w3, w1], dim=dim).contiguous(),
+            torch.cat([s3, s1], dim=dim).contiguous())
+
+
+def build_flashinfer_fp4_cutlass_moe_kernel(
+    moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
+    """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
+    experts = FlashInferExperts(
+        use_nvfp4_w4a4=True,
+        use_dp=moe_parallel_config.dp_size > 1,
+        ep_rank=moe_parallel_config.ep_rank,
+        ep_size=moe_parallel_config.ep_size,
+        tp_rank=moe_parallel_config.tp_rank,
+        tp_size=moe_parallel_config.tp_size,
+    )
+    logger.debug_once("FlashInferExperts (util)")
+    return mk.FusedMoEModularKernel(
+        FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
+        experts,
+    )
+
+
+def flashinfer_fp4_cutlass_moe_forward(
+    fused_experts: mk.FusedMoEModularKernel,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor],
+    apply_router_weight_on_input: bool,
+) -> torch.Tensor:
+    """Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
+
+    assert is_valid_flashinfer_cutlass_fused_moe(
+        x, layer.w13_weight,
+        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
+
+    a1_gscale = layer.w13_input_scale_quant
+    a2_gscale = layer.w2_input_scale_quant
+
+    extra_expert_args = {
+        "g1_alphas": layer.g1_alphas,
+        "g2_alphas": layer.g2_alphas,
+        # Named *_gscale to avoid confusion with a1_scale and a2_scale,
+        # which are batch-size related.
+        "a1_gscale": a1_gscale,
+        "a2_gscale": a2_gscale,
+        "out_dtype": x.dtype,
+    }
+    extra_prepare_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+        "a1_gscale": a1_gscale,
+    }
+    extra_finalize_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+    }
+
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,  # TODO(shuw): fix later, output is currently high precision
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=layer.w13_blockscale_swizzled,
+        w2_scale=layer.w2_blockscale_swizzled,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        extra_expert_args=extra_expert_args,
+        extra_prepare_args=extra_prepare_args,
+        extra_finalize_args=extra_finalize_args,
+    )
+
+
+def select_nvfp4_gemm_impl(
+    allow_flashinfer_cutlass: bool,
+    moe,  # FusedMoEConfig
+    logger):
+    """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
+
+    # lazy import
+    from vllm.distributed import get_ep_group
+
+    all2all_manager = get_ep_group().device_communicator.all2all_manager
+    assert all2all_manager is not None
+
+    if allow_flashinfer_cutlass:
+        logger.debug_once("Using FlashInferExperts")
+        return FlashInferExperts(
+            use_nvfp4_w4a4=True,
+            use_dp=moe.moe_parallel_config.dp_size > 1,
+            ep_rank=moe.moe_parallel_config.ep_rank,
+            ep_size=moe.moe_parallel_config.ep_size,
+            tp_rank=moe.moe_parallel_config.tp_rank,
+            tp_size=moe.moe_parallel_config.tp_size,
+        )
+
+    # Native CutlassExpertsFp4 currently doesn't support DP; the pure-TP case
+    # never calls select_gemm_impl, so reaching here without FlashInfer is an error.
+    raise ValueError(
+        "CutlassExpertsFp4 doesn't support DP. Use FlashInfer CUTLASS "
+        "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)")
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
new file mode 100644
index 000000000000..23a749467f19
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import vllm.envs as envs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    is_flashinfer_fp4_cutlass_moe_available)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
+
+__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
+
+_logger = init_logger(__name__)
+
+
+@dataclass(frozen=True)
+class NvFp4Support:
+    """Result container for NV-FP4 capability probing."""
+
+    cutlass_supported: bool
+    allow_flashinfer_cutlass: bool
+    use_marlin: bool
+
+
+def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
+    """Detect platform support for NV-FP4 fused-MoE path"""
+    cutlass_supported = cutlass_fp4_supported()
+
+    allow_flashinfer = (cutlass_supported
+                        and is_flashinfer_fp4_cutlass_moe_available())
+
+    if allow_flashinfer:
+        _logger.info_once("Using FlashInfer kernels for %s.", class_name
+                          or "NVFP4 path")
+    else:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
+            _logger.warning_once(
+                "FlashInfer kernels unavailable for %s on current platform.",
+                class_name or "NVFP4 path",
+            )
+
+    use_marlin = False
+    if not cutlass_supported:
+        if is_fp4_marlin_supported():
+            use_marlin = True
+            _logger.info_once("Falling back to Marlin FP4 MoE kernel.")
+        else:
+            raise ValueError(
+                "Current platform does not support NVFP4 quantization. "
+                "Please use Blackwell GPUs or enable FlashInfer.")
+
+    return NvFp4Support(
+        cutlass_supported=cutlass_supported,
+        allow_flashinfer_cutlass=allow_flashinfer,
+        use_marlin=use_marlin,
+    )
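
Reviewer note (not part of the patch): the snippet below is a minimal, self-contained sketch of what reorder_w1w3_to_w3w1 does, shown on toy fp32 tensors so the [w1, w3] -> [w3, w1] swap is easy to eyeball. The real call sites operate on packed NVFP4 weights and block scales, so the shapes and dtypes here are illustrative assumptions only.

# Standalone sketch mirroring reorder_w1w3_to_w3w1, demonstrated on toy tensors.
import torch


def reorder_w1w3_to_w3w1(weight, scale, dim=-2):
    size = weight.size(dim)
    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
    half = size // 2
    w1, w3 = weight.split(half, dim=dim)
    s1, s3 = scale.split(half, dim=dim)
    return (torch.cat([w3, w1], dim=dim).contiguous(),
            torch.cat([s3, s1], dim=dim).contiguous())


if __name__ == "__main__":
    num_experts, two_n, k = 2, 4, 8
    # Per expert, rows 0..1 stand in for the w1 half and rows 2..3 for w3.
    w13 = torch.arange(num_experts * two_n * k,
                       dtype=torch.float32).reshape(num_experts, two_n, k)
    s13 = torch.arange(num_experts * two_n * (k // 4),
                       dtype=torch.float32).reshape(num_experts, two_n, k // 4)

    w_new, s_new = reorder_w1w3_to_w3w1(w13, s13, dim=-2)

    # The former w3 half now leads, matching the w3/w1 ordering the
    # FlashInfer CUTLASS MoE kernel expects for the fused GEMM1 weight.
    assert torch.equal(w_new[:, :2], w13[:, 2:])
    assert torch.equal(w_new[:, 2:], w13[:, :2])
    assert torch.equal(s_new[:, :2], s13[:, 2:])
    print("reorder OK:", w_new.shape, s_new.shape)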
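
Reviewer note (not part of the patch): a hedged pytest sketch of how the three outcomes of detect_nvfp4_moe_support could be pinned down by patching the module-level probes it imports. The test file and function names are hypothetical, and the sketch assumes the vLLM quantization utils import cleanly in the test environment.

# Hypothetical test file, e.g. tests/quantization/test_nvfp4_moe_support.py
import pytest

import vllm.model_executor.layers.quantization.utils.nvfp4_moe_support as sup


def test_flashinfer_cutlass_path(monkeypatch):
    # CUTLASS NVFP4 available and FlashInfer gate open -> FlashInfer path.
    monkeypatch.setattr(sup, "cutlass_fp4_supported", lambda: True)
    monkeypatch.setattr(sup, "is_flashinfer_fp4_cutlass_moe_available",
                        lambda: True)
    res = sup.detect_nvfp4_moe_support("DummyMoEMethod")
    assert res.cutlass_supported
    assert res.allow_flashinfer_cutlass
    assert not res.use_marlin


def test_marlin_fallback(monkeypatch):
    # No CUTLASS NVFP4 support, but Marlin FP4 is available -> Marlin path.
    monkeypatch.setattr(sup, "cutlass_fp4_supported", lambda: False)
    monkeypatch.setattr(sup, "is_flashinfer_fp4_cutlass_moe_available",
                        lambda: False)
    monkeypatch.setattr(sup, "is_fp4_marlin_supported", lambda: True)
    res = sup.detect_nvfp4_moe_support("DummyMoEMethod")
    assert res.use_marlin
    assert not res.allow_flashinfer_cutlass


def test_unsupported_platform_raises(monkeypatch):
    # Neither CUTLASS nor Marlin available -> hard error.
    monkeypatch.setattr(sup, "cutlass_fp4_supported", lambda: False)
    monkeypatch.setattr(sup, "is_flashinfer_fp4_cutlass_moe_available",
                        lambda: False)
    monkeypatch.setattr(sup, "is_fp4_marlin_supported", lambda: False)
    with pytest.raises(ValueError):
        sup.detect_nvfp4_moe_support("DummyMoEMethod")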
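
Reviewer note (not part of the patch): a hedged end-to-end sketch of opting in to the FlashInfer CUTLASS NVFP4 MoE path that both CompressedTensorsW4A4MoeMethod and ModelOptNvFp4FusedMoE now route through. VLLM_USE_FLASHINFER_MOE_FP4 is the real gate checked by is_flashinfer_fp4_cutlass_moe_available(); the model id below is a placeholder for any NVFP4-quantized MoE checkpoint, and Blackwell (SM100) hardware plus a FlashInfer installation are still required at runtime, otherwise the quant methods fall back to the CUTLASS or Marlin paths as before.

# Hedged usage sketch; the checkpoint name is a placeholder.
import os

os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"  # gate read by vllm.envs

from vllm import LLM, SamplingParams

llm = LLM(model="org/nvfp4-moe-checkpoint",  # placeholder NVFP4 MoE model id
          tensor_parallel_size=1)
outputs = llm.generate(["FlashInfer CUTLASS MoE smoke test"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)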