diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 56326c9315ba..8b5713e02c95 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -99,7 +99,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
 
 def per_block_cast_to_fp8(x, *args, **kwargs):
     if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
-        return _per_block_cast_impl(x)
+        return _per_block_cast_impl(x, use_ue8m0=True)
     # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
     from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
     return _pbcf(x, *args, **kwargs)