From e13945f9dd60260ae8ed42d8d316d48401fe850b Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Sun, 15 Jun 2025 02:25:10 +0200 Subject: [PATCH] [Perf] Further tunings for SM100 FP8 CUTLASS kernel (#19566) --- .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 6da2da634075..1549ed96aa2b 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { - // M in (128, inf) + // M in (256, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_256, _128, _64>; + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M256 { + // M in (128, 256] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_128, _128, _64>; - using ClusterShape = Shape<_2, _2, _1>; + using TileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape<_2, _4, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; @@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, typename sm100_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm100_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = @@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, // m in (64, 128] return cutlass_gemm_caller( out, a, b, std::forward(args)...); + } else if (mp2 <= 256) { + // m in (128, 256] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); } else { - // m in (128, inf) + // m in (256, inf) return cutlass_gemm_caller( out, a, b, std::forward(args)...); }