mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 12:16:13 +08:00
[Perf] Further tunings for SM100 FP8 CUTLASS kernel (#19566)
This commit is contained in:
parent
08500011d3
commit
e13945f9dd
@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue>
|
template <typename, typename, typename> typename Epilogue>
|
||||||
struct sm100_fp8_config_default {
|
struct sm100_fp8_config_default {
|
||||||
// M in (128, inf)
|
// M in (256, inf)
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
using TileShape = Shape<_256, _128, _64>;
|
using TileShape = Shape<_256, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _2, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm100_fp8_config_M256 {
|
||||||
|
// M in (128, 256]
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _2, _1>;
|
using ClusterShape = Shape<_2, _2, _1>;
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
|
|||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
using TileShape = Shape<_128, _128, _64>;
|
using TileShape = Shape<_128, _128, _256>;
|
||||||
using ClusterShape = Shape<_2, _2, _1>;
|
using ClusterShape = Shape<_2, _4, _1>;
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
KernelSchedule, EpilogueSchedule>;
|
||||||
@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
|||||||
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM128 =
|
using Cutlass3xGemmM128 =
|
||||||
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM256 =
|
||||||
|
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
|||||||
// m in (64, 128]
|
// m in (64, 128]
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (mp2 <= 256) {
|
||||||
|
// m in (128, 256]
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM256>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
} else {
|
} else {
|
||||||
// m in (128, inf)
|
// m in (256, inf)
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user