[kernel] Improve FP8 PTPC on Hopper for larger shapes (#28692)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere 2025-11-14 12:59:11 -05:00 committed by GitHub
parent 085424808e
commit cdd7025961
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
ClusterShape, KernelSchedule, EpilogueSchedule>>;
};
template <typename InType, typename OutType, bool EnableBias>
struct sm90_fp8_config_M8192_K6144 {
// M >= 8192, K >= 6144
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = conditional_t<
EnableBias,
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogueBias,
TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule>,
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>>;
};
template <typename InType, typename OutType, bool EnableBias>
struct sm90_fp8_config_M128 {
// M in (64, 128]
@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
EnableBias>::Cutlass3xGemm;
using Cutlass3xGemmM8192_K6144 =
typename sm90_fp8_config_M8192_K6144<InType, OutType,
EnableBias>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, EnableBias>::Cutlass3xGemm;
@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0);
uint32_t const n = b.size(1);
uint32_t const k = a.size(1);
if (m <= 16) {
// m in [1, 16]
@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
// m in (64, 128]
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM128>(
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
} else if (m >= 8192 && k >= 6144) {
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM8192_K6144>(
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmDefault>(