diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index fb448de3c2341..a0808cb603d05 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -434,14 +434,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None

-        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
-            )
+
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if current_platform.is_rocm():
             self.use_marlin = False

+        # First check for Flashinfer MOE on Blackwell GPUs
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        if (current_platform.is_cuda()
+                and current_platform.is_device_capability(100)
+                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Detected Blackwell GPUs, using FlashInfer "
+                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
             if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using "
-                                    "DeepGemm kernels")
+                logger.warning_once("Model is not block quantized. Not using"
+                                    " DeepGemm kernels")
+            elif self.flashinfer_moe_backend:
+                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                                 " enabled.")
             elif (is_deep_gemm_supported()):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
@@ -471,15 +479,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             logger.debug_once("Model is not block quantized. Not using "
                               "CutlassBlockScaledGroupedGemm kernels")
         elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)):
+              and current_platform.is_device_capability(100)
+              and not self.flashinfer_moe_backend):
             logger.info_once(
-                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
-            )
+                "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
+                "on SM100.")
             self.allow_cutlass_block_scaled_grouped_gemm = True
-        else:
-            logger.warning_once(
-                "CutlassBlockScaledGroupedGemm not supported on the current "
-                "platform.")

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
@@ -934,7 +939,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
             assert (renormalize and use_grouped_topk
                     and custom_routing_function is None)
-            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+            e_score_correction_bias = (e_score_correction_bias.to(
+                x.dtype) if e_score_correction_bias is not None else None)
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                 routing_logits=router_logits.to(torch.float32),
                 routing_bias=e_score_correction_bias,
                 x=x,
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 0e3bdaec829e5..4f05f0bc35cc1 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
     is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)


 @functools.cache
@@ -46,6 +47,10 @@ def is_deep_gemm_e8m0_used() -> bool:
         logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False

+    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
+        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
+        return False
+
     if current_platform.is_device_capability(100) and \
             envs.VLLM_USE_DEEP_GEMM_E8M0:
         logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")