mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:45:29 +08:00
[Kernel] Update Cutlass int8 kernel configs for SM90 (#5514)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
1b2eaac316
commit
111af1fa2c
@ -234,15 +234,15 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue, int32_t M>
|
template <typename, typename, typename> typename Epilogue>
|
||||||
struct sm90_fp8_config {
|
struct sm90_fp8_config_default {
|
||||||
|
// M in (128, inf)
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule =
|
using KernelSchedule =
|
||||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
using TileShape = Shape<_128, _128, _128>;
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
KernelSchedule, EpilogueSchedule>;
|
||||||
@ -250,14 +250,14 @@ struct sm90_fp8_config {
|
|||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue>
|
template <typename, typename, typename> typename Epilogue>
|
||||||
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
|
struct sm90_fp8_config_M128 {
|
||||||
|
// M in (64, 128]
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule =
|
using KernelSchedule =
|
||||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
using TileShape = Shape<_64, _128, _128>;
|
using TileShape = Shape<_64, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
KernelSchedule, EpilogueSchedule>;
|
||||||
@ -265,7 +265,8 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
|
|||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue>
|
template <typename, typename, typename> typename Epilogue>
|
||||||
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
|
struct sm90_fp8_config_M64 {
|
||||||
|
// M in [1, 64]
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule =
|
using KernelSchedule =
|
||||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
@ -278,6 +279,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
|
|||||||
KernelSchedule, EpilogueSchedule>;
|
KernelSchedule, EpilogueSchedule>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_int8_config_default {
|
||||||
|
// For M > 128 and any N
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_int8_config_M128 {
|
||||||
|
// For M in (64, 128] and any N
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_int8_config_M64 {
|
||||||
|
// For M in (32, 64] and any N
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _64, _256>;
|
||||||
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_int8_config_M32_NBig {
|
||||||
|
// For M in [1, 32] and N >= 8192
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _128, _256>;
|
||||||
|
using ClusterShape = Shape<_1, _4, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_int8_config_M32_NSmall {
|
||||||
|
// For M in [1, 32] and N < 8192
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||||
|
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||||
|
using TileShape = Shape<_64, _64, _256>;
|
||||||
|
using ClusterShape = Shape<_1, _8, _1>;
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule>;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
@ -291,11 +364,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
using Cutlass3xGemmDefault =
|
using Cutlass3xGemmDefault =
|
||||||
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
|
typename sm90_fp8_config_default<InType, OutType,
|
||||||
|
Epilogue>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM64 =
|
using Cutlass3xGemmM64 =
|
||||||
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
|
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM128 =
|
using Cutlass3xGemmM128 =
|
||||||
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
|
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
@ -316,6 +390,61 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||||
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm90_int8_config_default<InType, OutType,
|
||||||
|
Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM128 =
|
||||||
|
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM64 =
|
||||||
|
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM32NBig =
|
||||||
|
typename sm90_int8_config_M32_NBig<InType, OutType,
|
||||||
|
Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM32NSmall =
|
||||||
|
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||||
|
Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
uint32_t const n = out.size(1);
|
||||||
|
bool const is_small_n = n < 8192;
|
||||||
|
|
||||||
|
uint32_t const m = a.size(0);
|
||||||
|
uint32_t const mp2 =
|
||||||
|
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||||
|
|
||||||
|
if (mp2 <= 32) {
|
||||||
|
// m in [1, 32]
|
||||||
|
if (is_small_n) {
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else {
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
} else if (mp2 <= 64) {
|
||||||
|
// m in (32, 64]
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (mp2 <= 128) {
|
||||||
|
// m in (64, 128]
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else {
|
||||||
|
// m in (128, inf)
|
||||||
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
@ -326,22 +455,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
if (a.dtype() == torch::kInt8) {
|
if (a.dtype() == torch::kInt8) {
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||||
|
|
||||||
using TileShape = Shape<_128, _128, _128>;
|
|
||||||
using ClusterShape = Shape<_1, _2, _1>;
|
|
||||||
using KernelSchedule =
|
|
||||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
|
||||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_gemm_caller<cutlass_3x_gemm<
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||||
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
|
ScaledEpilogue>(
|
||||||
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
|
||||||
return cutlass_gemm_caller<
|
ScaledEpilogue>(
|
||||||
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
|
|
||||||
ClusterShape, KernelSchedule, EpilogueSchedule>>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
out, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user