[Bugfix] Fix incorrect dispatch for CutlassBlockScaledGroupedGemm and DeepGEMM (#20933)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-07-15 10:42:17 +09:00 committed by GitHub
parent ba8c300018
commit bcdfb2a330
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -488,11 +488,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once("Failed to import DeepGemm kernels.") logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant: elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using " logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels") "DeepGemm kernels")
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)): and current_platform.is_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True self.allow_deep_gemm = True
elif (current_platform.is_cuda()
and is_blackwell_deep_gemm_used()):
logger.info_once("Using DeepGemm SM100 kernels for "
"Fp8MoEMethod.")
self.allow_deep_gemm = True
else: else:
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
@ -500,10 +505,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Check for CutlassBlockScaledGroupedGemm support. # Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant: if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using " logger.debug_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels") "CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)): and current_platform.is_device_capability(100)):
logger.info_once( logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
) )