mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 16:17:53 +08:00
[Perf] SM100 - add swap AB optimization to CUTLASS FP8 GEMM (#27284)
Signed-off-by: Faqin Zhong <faqin.zhong@gmail.com> Co-authored-by: Faqin Zhong <zhofaqin@amazon.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
5a0a6dfd55
commit
97e3dda84b
@ -1,6 +1,5 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
#include "scaled_mm_kernels.hpp"
|
||||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
if (bias) {
|
if (bias) {
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
"currently bias dtype must match output dtype ", out.dtype());
|
||||||
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
b_scales, *bias);
|
||||||
} else {
|
} else {
|
||||||
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
|
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||||
out, a, b, a_scales, b_scales);
|
b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "scaled_mm.cuh"
|
#include "scaled_mm.cuh"
|
||||||
#include "cutlass_gemm_caller.cuh"
|
#include "cutlass_gemm_caller.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
|
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
|
||||||
@ -12,8 +13,88 @@ namespace vllm {
|
|||||||
|
|
||||||
using c3x::cutlass_gemm_caller;
|
using c3x::cutlass_gemm_caller;
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename ElementAB_, typename ElementD_,
|
||||||
template <typename, typename, typename> typename Epilogue>
|
template <typename, typename, typename> typename Epilogue_,
|
||||||
|
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||||
|
typename EpilogueSchedule, bool swap_ab_ = false>
|
||||||
|
struct cutlass_3x_gemm_sm100_fp8 {
|
||||||
|
using ElementAB = ElementAB_;
|
||||||
|
using ElementC = ElementD_;
|
||||||
|
using ElementD = ElementD_;
|
||||||
|
using ElementAcc =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
|
float>::type;
|
||||||
|
|
||||||
|
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||||
|
|
||||||
|
using EVTCompute = typename Epilogue::EVTCompute;
|
||||||
|
|
||||||
|
static constexpr int AlignmentAB =
|
||||||
|
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||||
|
static constexpr int AlignmentCD =
|
||||||
|
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
|
||||||
|
// Compile-time swap_ab flag
|
||||||
|
static constexpr bool swap_ab = swap_ab_;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
// Layout definitions
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutA_T = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||||
|
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutB_T = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||||
|
|
||||||
|
using LayoutD = cutlass::layout::RowMajor;
|
||||||
|
using LayoutD_Transpose =
|
||||||
|
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||||
|
|
||||||
|
using LayoutC = LayoutD;
|
||||||
|
using LayoutC_Transpose = LayoutD_Transpose;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
// Collective epilogue (conditionally swap operands and layouts)
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
using CollectiveEpilogue =
|
||||||
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
||||||
|
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
ElementAcc, float, ElementC,
|
||||||
|
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, AlignmentCD,
|
||||||
|
ElementD, conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
|
||||||
|
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||||
|
|
||||||
|
static constexpr size_t CEStorageSize =
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||||
|
|
||||||
|
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||||
|
static_cast<int>(CEStorageSize)>;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
// Collective mainloop (conditionally swap operands and layouts)
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
using CollectiveMainloop = conditional_t<
|
||||||
|
swap_ab,
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||||
|
LayoutB_T, AlignmentAB, // Swapped B (as A)
|
||||||
|
ElementAB, LayoutA_T, AlignmentAB, // Swapped A (as B)
|
||||||
|
ElementAcc, TileShape, ClusterShape, Stages,
|
||||||
|
KernelSchedule>::CollectiveOp,
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||||
|
LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc,
|
||||||
|
TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
// Kernel definition
|
||||||
|
// -----------------------------------------------------------
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType, bool EnableBias>
|
||||||
struct sm100_fp8_config_default {
|
struct sm100_fp8_config_default {
|
||||||
// M in (256, inf)
|
// M in (256, inf)
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
@ -22,12 +103,16 @@ struct sm100_fp8_config_default {
|
|||||||
using TileShape = Shape<_256, _128, _128>;
|
using TileShape = Shape<_256, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _2, _1>;
|
using ClusterShape = Shape<_2, _2, _1>;
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
conditional_t<EnableBias,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogueBias, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType, bool EnableBias>
|
||||||
template <typename, typename, typename> typename Epilogue>
|
|
||||||
struct sm100_fp8_config_M256 {
|
struct sm100_fp8_config_M256 {
|
||||||
// M in (64, 256]
|
// M in (64, 256]
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
@ -36,44 +121,127 @@ struct sm100_fp8_config_M256 {
|
|||||||
using TileShape = Shape<_128, _128, _128>;
|
using TileShape = Shape<_128, _128, _128>;
|
||||||
using ClusterShape = Shape<_2, _1, _1>;
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
conditional_t<EnableBias,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogueBias, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType, bool EnableBias>
|
||||||
template <typename, typename, typename> typename Epilogue>
|
struct sm100_fp8_config_M64_swap_ab {
|
||||||
|
// This config is for M in (16, 64] and K >= 4096
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_128, _64, _256>;
|
||||||
|
using ClusterShape = Shape<_4, _1, _1>;
|
||||||
|
|
||||||
|
// Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
|
||||||
|
// AB
|
||||||
|
using Cutlass3xGemm = conditional_t<
|
||||||
|
EnableBias,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
|
||||||
|
TileShape, ClusterShape, KernelSchedule,
|
||||||
|
EpilogueSchedule, true>,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||||
|
true>>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType, bool EnableBias>
|
||||||
struct sm100_fp8_config_M64 {
|
struct sm100_fp8_config_M64 {
|
||||||
// M in (16, 64]
|
// This config is for M = 64 and K < 4096 (do not enable swap AB in such case)
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
using TileShape = Shape<_64, _64, _128>;
|
using TileShape = Shape<_64, _64, _128>;
|
||||||
using ClusterShape = Shape<_1, _1, _1>;
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
conditional_t<EnableBias,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogueBias, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<
|
||||||
|
InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType, bool EnableBias>
|
||||||
template <typename, typename, typename> typename Epilogue>
|
struct sm100_fp8_config_M16_swap_ab {
|
||||||
struct sm100_fp8_config_M16 {
|
|
||||||
// M in [1, 16]
|
// M in [1, 16]
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
using TileShape = Shape<_64, _64, _128>;
|
using TileShape = Shape<_128, _32, _128>;
|
||||||
using ClusterShape = Shape<_1, _4, _1>;
|
using ClusterShape = Shape<_4, _1, _1>;
|
||||||
using Cutlass3xGemm =
|
|
||||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
// Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
|
||||||
KernelSchedule, EpilogueSchedule>;
|
// AB
|
||||||
|
using Cutlass3xGemm = conditional_t<
|
||||||
|
EnableBias,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
|
||||||
|
TileShape, ClusterShape, KernelSchedule,
|
||||||
|
EpilogueSchedule, true>,
|
||||||
|
cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
|
||||||
|
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||||
|
true>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
template <typename, typename, typename> typename Epilogue,
|
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_params) {
|
||||||
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
|
|
||||||
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||||
|
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||||
|
|
||||||
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
|
auto prob_shape =
|
||||||
|
swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
||||||
|
|
||||||
|
StrideA a_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||||
|
StrideB b_stride =
|
||||||
|
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||||
|
StrideC c_stride = cutlass::make_cute_packed_stride(
|
||||||
|
StrideC{},
|
||||||
|
swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
|
||||||
|
|
||||||
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||||
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||||
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||||
|
|
||||||
|
typename GemmKernel::MainloopArguments mainloop_args =
|
||||||
|
swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr,
|
||||||
|
a_stride}
|
||||||
|
: typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr,
|
||||||
|
b_stride};
|
||||||
|
|
||||||
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
|
Gemm::Epilogue::prepare_args(
|
||||||
|
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||||
|
c_ptr, c_stride, c_ptr, c_stride};
|
||||||
|
|
||||||
|
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||||
|
epilogue_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType, bool EnableBias,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
@ -81,55 +249,69 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
|||||||
|
|
||||||
using Cutlass3xGemmDefault =
|
using Cutlass3xGemmDefault =
|
||||||
typename sm100_fp8_config_default<InType, OutType,
|
typename sm100_fp8_config_default<InType, OutType,
|
||||||
Epilogue>::Cutlass3xGemm;
|
EnableBias>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM16 =
|
using Cutlass3xGemmM16SwapAB =
|
||||||
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
|
typename sm100_fp8_config_M16_swap_ab<InType, OutType,
|
||||||
|
EnableBias>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM64SwapAB =
|
||||||
|
typename sm100_fp8_config_M64_swap_ab<InType, OutType,
|
||||||
|
EnableBias>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM64 =
|
using Cutlass3xGemmM64 =
|
||||||
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
typename sm100_fp8_config_M64<InType, OutType, EnableBias>::Cutlass3xGemm;
|
||||||
|
|
||||||
using Cutlass3xGemmM256 =
|
using Cutlass3xGemmM256 =
|
||||||
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
typename sm100_fp8_config_M256<InType, OutType,
|
||||||
|
EnableBias>::Cutlass3xGemm;
|
||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const k = a.size(1);
|
||||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
|
||||||
|
|
||||||
if (mp2 <= 16) {
|
if (m <= 16) {
|
||||||
// m in [1, 16]
|
// m in [1, 16]
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmM16>(
|
return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM16SwapAB>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
|
||||||
} else if (mp2 <= 64) {
|
} else if (m <= 64) {
|
||||||
// m in (16, 64]
|
// m in (16, 64]
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
if (m == 64 && k < 4096) {
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
// do not enable swap AB
|
||||||
} else if (mp2 <= 256) {
|
return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64>(
|
||||||
|
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64SwapAB>(
|
||||||
|
out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
|
||||||
|
|
||||||
|
} else if (m <= 256) {
|
||||||
// m in (64, 256]
|
// m in (64, 256]
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmM256>(
|
return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM256>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
|
||||||
} else {
|
} else {
|
||||||
// m in (256, inf)
|
// m in (256, inf)
|
||||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmDefault>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <template <typename, typename, typename> typename Epilogue,
|
template <bool EnableBias, typename... EpilogueArgs>
|
||||||
typename... EpilogueArgs>
|
|
||||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
|
torch::Tensor const& a_scales,
|
||||||
|
torch::Tensor const& b_scales,
|
||||||
EpilogueArgs&&... epilogue_args) {
|
EpilogueArgs&&... epilogue_args) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.dtype() == torch::kBFloat16) {
|
||||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::bfloat16_t, Epilogue>(
|
cutlass::bfloat16_t, EnableBias>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, a_scales, b_scales,
|
||||||
|
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::half_t, Epilogue>(
|
cutlass::half_t, EnableBias>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, a_scales, b_scales,
|
||||||
|
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user