[Perf] Further tunings for SM100 FP8 CUTLASS kernel (#19566)

This commit is contained in:
Ilya Markov 2025-06-15 02:25:10 +02:00 committed by GitHub
parent 08500011d3
commit e13945f9dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
// M in (128, inf)
// M in (256, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _64>;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
// M in (128, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_2, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// m in (128, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
// m in (256, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}