mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 10:16:20 +08:00
[Bugfix] Fix incorrect dispatch for CutlassBlockScaledGroupedGemm and DeepGEMM (#20933)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
ba8c300018
commit
bcdfb2a330
@@ -488,11 +488,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
logger.warning_once("Failed to import DeepGemm kernels.")
|
logger.warning_once("Failed to import DeepGemm kernels.")
|
||||||
elif not self.block_quant:
|
elif not self.block_quant:
|
||||||
logger.warning_once("Model is not block quantized. Not using "
|
logger.warning_once("Model is not block quantized. Not using "
|
||||||
" DeepGemm kernels")
|
"DeepGemm kernels")
|
||||||
elif (current_platform.is_cuda()
|
elif (current_platform.is_cuda()
|
||||||
and current_platform.has_device_capability(90)):
|
and current_platform.is_device_capability(90)):
|
||||||
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
||||||
self.allow_deep_gemm = True
|
self.allow_deep_gemm = True
|
||||||
|
elif (current_platform.is_cuda()
|
||||||
|
and is_blackwell_deep_gemm_used()):
|
||||||
|
logger.info_once("Using DeepGemm SM100 kernels for "
|
||||||
|
"Fp8MoEMethod.")
|
||||||
|
self.allow_deep_gemm = True
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"DeepGemm not supported on the current platform.")
|
"DeepGemm not supported on the current platform.")
|
||||||
@@ -500,10 +505,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# Check for CutlassBlockScaledGroupedGemm support.
|
# Check for CutlassBlockScaledGroupedGemm support.
|
||||||
self.allow_cutlass_block_scaled_grouped_gemm = False
|
self.allow_cutlass_block_scaled_grouped_gemm = False
|
||||||
if not self.block_quant:
|
if not self.block_quant:
|
||||||
logger.warning_once("Model is not block quantized. Not using "
|
logger.debug_once("Model is not block quantized. Not using "
|
||||||
"CutlassBlockScaledGroupedGemm kernels")
|
"CutlassBlockScaledGroupedGemm kernels")
|
||||||
elif (current_platform.is_cuda()
|
elif (current_platform.is_cuda()
|
||||||
and current_platform.has_device_capability(100)):
|
and current_platform.is_device_capability(100)):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
|
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user