From cdd7025961cf79480f885804c21e7d60866fb33f Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Fri, 14 Nov 2025 12:59:11 -0500 Subject: [PATCH] [kernel] Improve FP8 PTPC on Hopper for larger shapes (#28692) Signed-off-by: czhu-cohere --- .../c3x/scaled_mm_sm90_fp8_dispatch.cuh | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 4ff3e65f2b2e1..b8433214be1ba 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -116,6 +116,26 @@ struct sm90_fp8_config_default { ClusterShape, KernelSchedule, EpilogueSchedule>>; }; +template +struct sm90_fp8_config_M8192_K6144 { + // M >= 8192, K >= 6144 + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; +}; + template struct sm90_fp8_config_M128 { // M in (64, 128] @@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM8192_K6144 = + typename sm90_fp8_config_M8192_K6144::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_fp8_config_M128::Cutlass3xGemm; @@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const n = b.size(1); + uint32_t const k = a.size(1); if (m <= 16) { // m in [1, 16] @@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, // m in (64, 128] return cutlass_gemm_caller_sm90_fp8( out, a, b, a_scales, b_scales, std::forward(args)...); + } else if (m >= 8192 && k >= 6144) { + return cutlass_gemm_caller_sm90_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } else { // m in (128, inf) return cutlass_gemm_caller_sm90_fp8(