From bcdfb2a3308e14fbf46da6d6d41747f289af9300 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 15 Jul 2025 10:42:17 +0900 Subject: [PATCH] [Bugfix] Fix incorrect dispatch for CutlassBlockScaledGroupedGemm and DeepGEMM (#20933) Signed-off-by: mgoin --- vllm/model_executor/layers/quantization/fp8.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 59db3e6c4449b..824dfe15ae250 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -488,11 +488,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): logger.warning_once("Failed to import DeepGemm kernels.") elif not self.block_quant: logger.warning_once("Model is not block quantized. Not using " - " DeepGemm kernels") + "DeepGemm kernels") elif (current_platform.is_cuda() - and current_platform.has_device_capability(90)): + and current_platform.is_device_capability(90)): logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") self.allow_deep_gemm = True + elif (current_platform.is_cuda() + and is_blackwell_deep_gemm_used()): + logger.info_once("Using DeepGemm SM100 kernels for " + "Fp8MoEMethod.") + self.allow_deep_gemm = True else: logger.warning_once( "DeepGemm not supported on the current platform.") @@ -500,10 +505,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): # Check for CutlassBlockScaledGroupedGemm support. self.allow_cutlass_block_scaled_grouped_gemm = False if not self.block_quant: - logger.warning_once("Model is not block quantized. Not using " - "CutlassBlockScaledGroupedGemm kernels") + logger.debug_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") elif (current_platform.is_cuda() - and current_platform.has_device_capability(100)): + and current_platform.is_device_capability(100)): logger.info_once( "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." )