[Refactor] Optimize FP8 MOE Backend Choice and Log (#26044)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye, 2025-10-03 17:23:42 -04:00 (committed by GitHub)
Parent: 0879736aab
Commit: c1ffcb55da


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
+
+class Fp8MoeBackend(Enum):
+    NONE = 0
+    FLASHINFER_TRTLLM = 1
+    FLASHINFER_CUTLASS = 2
+    DEEPGEMM = 3
+    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
+    MARLIN = 5
+    TRITON = 6
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary FP8 MoE backend.
+    Note: Shape-specific fallbacks may still occur at runtime.
+    """
+    # prefer FlashInfer backends when available and enabled on supported GPUs
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100)
+            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+        backend = get_flashinfer_moe_backend()
+        if backend == FlashinferMoeBackend.TENSORRT_LLM:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_TRTLLM
+        else:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+    # weight-only path for older GPUs without native FP8
+    use_marlin = (not current_platform.has_device_capability(89)
+                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+    if current_platform.is_rocm():
+        use_marlin = False
+    if use_marlin:
+        logger.info_once("Using Marlin backend for FP8 MoE")
+        return Fp8MoeBackend.MARLIN
+
+    # deepGEMM on supported platforms with block-quantized weights
+    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+        if not has_deep_gemm():
+            logger.warning_once(
+                "DeepGEMM backend requested but not available.")
+        elif is_deep_gemm_supported():
+            logger.info_once("Using DeepGEMM backend for FP8 MoE")
+            return Fp8MoeBackend.DEEPGEMM
+
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100) and block_quant):
+        logger.info_once(
+            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+
+    # default to Triton
+    logger.info_once("Using Triton backend for FP8 MoE")
+    return Fp8MoeBackend.TRITON
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
 
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if current_platform.is_rocm():
-            self.use_marlin = False
+        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
 
-        # First check for Flashinfer MOE on Blackwell GPUs
+        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
+
         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
-        if (current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Detected Blackwell GPUs, using FlashInfer "
-                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
-
-        # Check for DeepGemm support.
-        self.allow_deep_gemm = False
-        if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm():
-                logger.warning_once("Failed to import DeepGemm kernels.")
-            elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using"
-                                    " DeepGemm kernels")
-            elif self.flashinfer_moe_backend:
-                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
-                                 " enabled.")
-            elif (is_deep_gemm_supported()):
-                logger.debug_once(
-                    "DeepGemm kernels available for Fp8MoEMethod.")
-                self.allow_deep_gemm = True
-            else:
-                logger.warning_once(
-                    "DeepGemm not supported on the current platform.")
-
-        # Check for CutlassBlockScaledGroupedGemm support.
-        self.allow_cutlass_block_scaled_grouped_gemm = False
-        if not self.block_quant:
-            logger.debug_once("Model is not block quantized. Not using "
-                              "CutlassBlockScaledGroupedGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)
-              and not self.flashinfer_moe_backend):
-            logger.debug_once(
-                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
-            self.allow_cutlass_block_scaled_grouped_gemm = True
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
+        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
+        self.allow_cutlass_block_scaled_grouped_gemm = (
+            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+        )
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
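With the selection centralized, `Fp8MoEMethod.__init__` reduces to deriving a handful of flags from the single `Fp8MoeBackend` value. Assuming the mapping shown in the diff above, the derived flags work out as follows (a hypothetical summary for reference, not vLLM code):

# Per-backend flags derived in Fp8MoEMethod.__init__, per the diff above:
# (use_marlin, flashinfer_moe_backend, allow_deep_gemm,
#  allow_cutlass_block_scaled_grouped_gemm)
DERIVED_FLAGS = {
    "FLASHINFER_TRTLLM":                 (False, "TENSORRT_LLM", False, False),
    "FLASHINFER_CUTLASS":                (False, "CUTLASS",      False, False),
    "MARLIN":                            (True,  None,           False, False),
    "DEEPGEMM":                          (False, None,           True,  False),
    "CUTLASS_BLOCK_SCALED_GROUPED_GEMM": (False, None,           False, True),
    "TRITON":                            (False, None,           False, False),
}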