From c12df53b60d246f5cffd5f011b3141b9e7b5307b Mon Sep 17 00:00:00 2001 From: TherLF <54900723+Ther-LF@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:38:42 +0800 Subject: [PATCH] =?UTF-8?q?[Bugfix]=20Fix=20cutlass=20dispatch=20for=20fp8?= =?UTF-8?q?/int8=20to=20properly=20invoke=20M<=3D16=20c=E2=80=A6=20(#16751?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ther-LF <2639852836@qq.com> --- .../cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh | 2 +- .../cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index 4e82c99c3af31..6082937e7e1f9 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh index 95723b31ca3ce..87be125b2eb3c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh @@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16]