[Bug] DeepGemm: Fix TypeError: per_block_cast_to_fp8() missing 1 required positional argument: 'use_ue8m0' for SM100 (#21187)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-07-19 02:25:22 -04:00 committed by GitHub
parent 468e2400fe
commit 37bd8d6e4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -99,7 +99,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
def per_block_cast_to_fp8(x, *args, **kwargs): def per_block_cast_to_fp8(x, *args, **kwargs):
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used(): if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x) return _per_block_cast_impl(x, use_ue8m0=True)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
return _pbcf(x, *args, **kwargs) return _pbcf(x, *args, **kwargs)