From c1ffcb55da6a63b9db6d1fa984b93e1b9d5dbcc6 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 3 Oct 2025 17:23:42 -0400
Subject: [PATCH] [Refactor] Optimize FP8 MOE Backend Choice and Log (#26044)

Signed-off-by: yewentao256
---
 .../model_executor/layers/quantization/fp8.py | 115 +++++++++++-------
 1 file changed, 70 insertions(+), 45 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9b7b3f18baa7..dbcf4b2fbee5 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
 
+class Fp8MoeBackend(Enum):
+    NONE = 0
+    FLASHINFER_TRTLLM = 1
+    FLASHINFER_CUTLASS = 2
+    DEEPGEMM = 3
+    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
+    MARLIN = 5
+    TRITON = 6
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary FP8 MoE backend
+    Note: Shape-specific fallbacks may still occur at runtime.
+    """
+    # prefer FlashInfer backends when available and enabled on supported GPUs
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100)
+            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+        backend = get_flashinfer_moe_backend()
+        if backend == FlashinferMoeBackend.TENSORRT_LLM:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_TRTLLM
+        else:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+    # weight-only path for older GPUs without native FP8
+    use_marlin = (not current_platform.has_device_capability(89)
+                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+    if current_platform.is_rocm():
+        use_marlin = False
+    if use_marlin:
+        logger.info_once("Using Marlin backend for FP8 MoE")
+        return Fp8MoeBackend.MARLIN
+
+    # deepGEMM on supported platforms with block-quantized weights
+    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+        if not has_deep_gemm():
+            logger.warning_once(
+                "DeepGEMM backend requested but not available.")
+        elif is_deep_gemm_supported():
+            logger.info_once("Using DeepGEMM backend for FP8 MoE")
+            return Fp8MoeBackend.DEEPGEMM
+
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100) and block_quant):
+        logger.info_once(
+            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+
+    # default to Triton
+    logger.info_once("Using Triton backend for FP8 MoE")
+    return Fp8MoeBackend.TRITON
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
 
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if current_platform.is_rocm():
-            self.use_marlin = False
+        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
 
-        # First check for Flashinfer MOE on Blackwell GPUs
+        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
-        if (current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Detected Blackwell GPUs, using FlashInfer "
-                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
 
-        # Check for DeepGemm support.
-        self.allow_deep_gemm = False
-        if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm():
-                logger.warning_once("Failed to import DeepGemm kernels.")
-            elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using"
-                                    " DeepGemm kernels")
-            elif self.flashinfer_moe_backend:
-                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
-                                 " enabled.")
-            elif (is_deep_gemm_supported()):
-                logger.debug_once(
-                    "DeepGemm kernels available for Fp8MoEMethod.")
-                self.allow_deep_gemm = True
-            else:
-                logger.warning_once(
-                    "DeepGemm not supported on the current platform.")
-
-        # Check for CutlassBlockScaledGroupedGemm support.
-        self.allow_cutlass_block_scaled_grouped_gemm = False
-        if not self.block_quant:
-            logger.debug_once("Model is not block quantized. Not using "
-                              "CutlassBlockScaledGroupedGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)
-              and not self.flashinfer_moe_backend):
-            logger.debug_once(
-                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
-            self.allow_cutlass_block_scaled_grouped_gemm = True
+        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
+        self.allow_cutlass_block_scaled_grouped_gemm = (
+            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+        )
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
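
The selection priority encoded by get_fp8_moe_backend() is: FlashInfer (TRTLLM or CUTLASS) on SM100, Marlin for GPUs without native FP8, DeepGEMM and then CUTLASS block-scaled grouped GEMM for block-quantized weights, and Triton as the default. Below is a minimal standalone sketch of that ordering; choose_backend and its boolean parameters are illustrative stand-ins for the current_platform/envs probes used in the patch, not part of vLLM's API.

from enum import Enum


class Fp8MoeBackend(Enum):
    NONE = 0
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    MARLIN = 5
    TRITON = 6


def choose_backend(block_quant: bool,
                   is_sm100: bool = False,
                   flashinfer_enabled: bool = False,
                   flashinfer_trtllm: bool = False,
                   has_native_fp8: bool = True,
                   deep_gemm_enabled: bool = False) -> Fp8MoeBackend:
    # Illustrative stand-in for get_fp8_moe_backend(): same priority order,
    # with plain booleans replacing the platform/environment checks.
    if is_sm100 and flashinfer_enabled:
        return (Fp8MoeBackend.FLASHINFER_TRTLLM
                if flashinfer_trtllm else Fp8MoeBackend.FLASHINFER_CUTLASS)
    if not has_native_fp8:
        return Fp8MoeBackend.MARLIN  # weight-only fallback for older GPUs
    if block_quant and deep_gemm_enabled:
        return Fp8MoeBackend.DEEPGEMM
    if block_quant and is_sm100:
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
    return Fp8MoeBackend.TRITON  # default


# Block-quantized weights on SM100 without FlashInfer or DeepGEMM resolve
# to the CUTLASS block-scaled grouped GEMM path; everything else falls
# through to Triton.
assert (choose_backend(block_quant=True, is_sm100=True)
        is Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM)
assert choose_backend(block_quant=False) is Fp8MoeBackend.TRITON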