[Refactor] Optimize FP8 MOE Backend Choice and Log (#26044)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 0879736aab
commit c1ffcb55da
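
This change collapses the scattered per-backend checks in Fp8MoEMethod.__init__ into a single Fp8MoeBackend enum plus a get_fp8_moe_backend() selector that logs the chosen backend once. As a minimal sketch of how the new selector is consumed after this refactor: the import path and the ExampleMoEMethod class below are illustrative assumptions, while the enum, the selector, and the derived flags mirror the diff that follows.

# Illustrative sketch only: the module path is assumed to be
# vllm/model_executor/layers/quantization/fp8.py, and ExampleMoEMethod
# is a simplified stand-in for the real Fp8MoEMethod.
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8MoeBackend, get_fp8_moe_backend)


class ExampleMoEMethod:

    def __init__(self, block_quant: bool):
        # One call decides the backend (and logs the decision once).
        self.fp8_backend = get_fp8_moe_backend(block_quant)

        # Per-backend feature flags reduce to equality checks on the enum.
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend
            == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM)
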
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
+class Fp8MoeBackend(Enum):
+    NONE = 0
+    FLASHINFER_TRTLLM = 1
+    FLASHINFER_CUTLASS = 2
+    DEEPGEMM = 3
+    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
+    MARLIN = 5
+    TRITON = 6
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary FP8 MoE backend
+    Note: Shape-specific fallbacks may still occur at runtime.
+    """
+    # prefer FlashInfer backends when available and enabled on supported GPUs
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100)
+            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+        backend = get_flashinfer_moe_backend()
+        if backend == FlashinferMoeBackend.TENSORRT_LLM:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_TRTLLM
+        else:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+    # weight-only path for older GPUs without native FP8
+    use_marlin = (not current_platform.has_device_capability(89)
+                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+    if current_platform.is_rocm():
+        use_marlin = False
+    if use_marlin:
+        logger.info_once("Using Marlin backend for FP8 MoE")
+        return Fp8MoeBackend.MARLIN
+
+    # deepGEMM on supported platforms with block-quantized weights
+    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+        if not has_deep_gemm():
+            logger.warning_once(
+                "DeepGEMM backend requested but not available.")
+        elif is_deep_gemm_supported():
+            logger.info_once("Using DeepGEMM backend for FP8 MoE")
+            return Fp8MoeBackend.DEEPGEMM
+
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100) and block_quant):
+        logger.info_once(
+            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+
+    # default to Triton
+    logger.info_once("Using Triton backend for FP8 MoE")
+    return Fp8MoeBackend.TRITON
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
 
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if current_platform.is_rocm():
-            self.use_marlin = False
+        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
 
-        # First check for Flashinfer MOE on Blackwell GPUs
+        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
-        if (current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Detected Blackwell GPUs, using FlashInfer "
-                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
 
-        # Check for DeepGemm support.
-        self.allow_deep_gemm = False
-        if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm():
-                logger.warning_once("Failed to import DeepGemm kernels.")
-            elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using"
-                                    " DeepGemm kernels")
-            elif self.flashinfer_moe_backend:
-                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
-                                 " enabled.")
-            elif (is_deep_gemm_supported()):
-                logger.debug_once(
-                    "DeepGemm kernels available for Fp8MoEMethod.")
-                self.allow_deep_gemm = True
-            else:
-                logger.warning_once(
-                    "DeepGemm not supported on the current platform.")
-
-        # Check for CutlassBlockScaledGroupedGemm support.
-        self.allow_cutlass_block_scaled_grouped_gemm = False
-        if not self.block_quant:
-            logger.debug_once("Model is not block quantized. Not using "
-                              "CutlassBlockScaledGroupedGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)
-              and not self.flashinfer_moe_backend):
-            logger.debug_once(
-                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
-            self.allow_cutlass_block_scaled_grouped_gemm = True
+        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
+        self.allow_cutlass_block_scaled_grouped_gemm = (
+            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+        )
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
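
For reference, the selection order implemented by get_fp8_moe_backend() can be steered through the environment variables it reads. A minimal sketch, assuming only the variable names and the precedence shown in the diff above; the values and the surrounding setup are illustrative.

import os

# Environment knobs read (via vllm.envs) by get_fp8_moe_backend();
# the names come from the diff above, the values here are examples only.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"  # prefer FlashInfer on SM100
os.environ["VLLM_USE_DEEP_GEMM"] = "1"           # allow DeepGEMM (block-quantized weights only)
os.environ["VLLM_TEST_FORCE_FP8_MARLIN"] = "0"   # do not force the Marlin weight-only path

# Resulting precedence for block-quantized FP8 MoE weights:
#   FlashInfer TRTLLM/CUTLASS (SM100 + flashinfer installed)
#   -> Marlin (no SM89 FP8 support or forced; disabled on ROCm)
#   -> DeepGEMM -> CUTLASS block-scaled grouped GEMM (SM100) -> Triton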