[Refactor] Optimize FP8 MOE Backend Choice and Log (#26044)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 0879736aab
commit c1ffcb55da
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
+class Fp8MoeBackend(Enum):
+    NONE = 0
+    FLASHINFER_TRTLLM = 1
+    FLASHINFER_CUTLASS = 2
+    DEEPGEMM = 3
+    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
+    MARLIN = 5
+    TRITON = 6
+
+
+def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+    """
+    Select the primary FP8 MoE backend
+    Note: Shape-specific fallbacks may still occur at runtime.
+    """
+    # prefer FlashInfer backends when available and enabled on supported GPUs
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100)
+            and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+        backend = get_flashinfer_moe_backend()
+        if backend == FlashinferMoeBackend.TENSORRT_LLM:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_TRTLLM
+        else:
+            logger.info_once(
+                "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
+            return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+    # weight-only path for older GPUs without native FP8
+    use_marlin = (not current_platform.has_device_capability(89)
+                  or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+    if current_platform.is_rocm():
+        use_marlin = False
+    if use_marlin:
+        logger.info_once("Using Marlin backend for FP8 MoE")
+        return Fp8MoeBackend.MARLIN
+
+    # deepGEMM on supported platforms with block-quantized weights
+    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+        if not has_deep_gemm():
+            logger.warning_once(
+                "DeepGEMM backend requested but not available.")
+        elif is_deep_gemm_supported():
+            logger.info_once("Using DeepGEMM backend for FP8 MoE")
+            return Fp8MoeBackend.DEEPGEMM
+
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(100) and block_quant):
+        logger.info_once(
+            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+
+    # default to Triton
+    logger.info_once("Using Triton backend for FP8 MoE")
+    return Fp8MoeBackend.TRITON
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
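The new helper encodes a fixed priority order: FlashInfer (TRTLLM or CUTLASS) on SM100 GPUs when VLLM_USE_FLASHINFER_MOE_FP8 is set, then Marlin as the weight-only path for GPUs without native FP8, then DeepGEMM and the CUTLASS block-scaled grouped GEMM for block-quantized weights, and finally Triton as the default. Below is a minimal usage sketch, not part of this commit, assuming the change lands in vllm/model_executor/layers/quantization/fp8.py (the module that defines Fp8Config and Fp8MoEMethod):

# Hypothetical usage sketch, not part of the commit: callers branch on the
# single enum returned by get_fp8_moe_backend instead of separate booleans.
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8MoeBackend, get_fp8_moe_backend)

backend = get_fp8_moe_backend(block_quant=True)
if backend in (Fp8MoeBackend.FLASHINFER_TRTLLM,
               Fp8MoeBackend.FLASHINFER_CUTLASS):
    ...  # FlashInfer kernels selected on an SM100 (Blackwell) device
elif backend == Fp8MoeBackend.DEEPGEMM:
    ...  # DeepGEMM requested, importable, and supported on this platform
else:
    ...  # Marlin, CUTLASS block-scaled grouped GEMM, or the Triton default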
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
 
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if current_platform.is_rocm():
-            self.use_marlin = False
-
-        # First check for Flashinfer MOE on Blackwell GPUs
+        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
+
+        self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN)
         self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
-        if (current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Detected Blackwell GPUs, using FlashInfer "
-                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
-
-        # Check for DeepGemm support.
-        self.allow_deep_gemm = False
-        if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm():
-                logger.warning_once("Failed to import DeepGemm kernels.")
-            elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using"
-                                    " DeepGemm kernels")
-            elif self.flashinfer_moe_backend:
-                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
-                                 " enabled.")
-            elif (is_deep_gemm_supported()):
-                logger.debug_once(
-                    "DeepGemm kernels available for Fp8MoEMethod.")
-                self.allow_deep_gemm = True
-            else:
-                logger.warning_once(
-                    "DeepGemm not supported on the current platform.")
-
-        # Check for CutlassBlockScaledGroupedGemm support.
-        self.allow_cutlass_block_scaled_grouped_gemm = False
-        if not self.block_quant:
-            logger.debug_once("Model is not block quantized. Not using "
-                              "CutlassBlockScaledGroupedGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)
-              and not self.flashinfer_moe_backend):
-            logger.debug_once(
-                "CutlassBlockScaledGroupedGemm available for Fp8MoEMethod.")
-            self.allow_cutlass_block_scaled_grouped_gemm = True
+        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
+        self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM)
+        self.allow_cutlass_block_scaled_grouped_gemm = (
+            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+        )
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
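After this hunk, the per-backend attributes on Fp8MoEMethod are plain derivations of the single fp8_backend field rather than repeated platform and environment checks, so existing dispatch code that reads those flags keeps working unchanged. A hedged consumer sketch (attribute names come from the diff; the function itself is hypothetical):

# Hypothetical dispatch sketch, not part of the commit: downstream code can
# keep reading the derived flags exactly as before the refactor.
def describe_fp8_moe_kernel(method) -> str:
    # `method` is an Fp8MoEMethod instance after __init__ has run
    if method.flashinfer_moe_backend is not None:
        return f"flashinfer:{method.flashinfer_moe_backend.value}"
    if method.use_marlin:
        return "marlin"
    if method.allow_deep_gemm:
        return "deep_gemm"
    if method.allow_cutlass_block_scaled_grouped_gemm:
        return "cutlass_block_scaled_grouped_gemm"
    return "triton"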