mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 09:07:03 +08:00
[kernel] Improve FP8 PTPC on Hopper for larger shapes (#28692)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
parent
085424808e
commit
cdd7025961
@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias>
|
||||
struct sm90_fp8_config_M8192_K6144 {
|
||||
// M >= 8192, K >= 6144
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_256, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm = conditional_t<
|
||||
EnableBias,
|
||||
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogueBias,
|
||||
TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>,
|
||||
cutlass_3x_gemm_sm90_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias>
|
||||
struct sm90_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config_default<InType, OutType,
|
||||
EnableBias>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM8192_K6144 =
|
||||
typename sm90_fp8_config_M8192_K6144<InType, OutType,
|
||||
EnableBias>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config_M128<InType, OutType, EnableBias>::Cutlass3xGemm;
|
||||
|
||||
@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const n = b.size(1);
|
||||
uint32_t const k = a.size(1);
|
||||
|
||||
if (m <= 16) {
|
||||
// m in [1, 16]
|
||||
@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM128>(
|
||||
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (m >= 8192 && k >= 6144) {
|
||||
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmM8192_K6144>(
|
||||
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_gemm_caller_sm90_fp8<Cutlass3xGemmDefault>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user