diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu index e092c61abc249..1db6c41bf9535 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu @@ -1,6 +1,5 @@ #include "scaled_mm_kernels.hpp" #include "scaled_mm_sm90_fp8_dispatch.cuh" -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" namespace vllm { @@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm90_fp8_epilogue( - out, a, b, a_scales, b_scales, *bias); + return cutlass_scaled_mm_sm90_fp8_epilogue(out, a, b, a_scales, + b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_fp8_epilogue( - out, a, b, a_scales, b_scales); + return cutlass_scaled_mm_sm90_fp8_epilogue(out, a, b, a_scales, + b_scales); } } diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 32ea5db3321bc..4ff3e65f2b2e1 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -2,6 +2,7 @@ #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" /** * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm @@ -12,8 +13,91 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, bool swap_ab_ = false> +struct cutlass_3x_gemm_sm90_fp8 { + using ElementAB = ElementAB_; + using ElementC = ElementD_; + using ElementD = ElementD_; + using ElementAcc = + 
typename std::conditional, int32_t, + float>::type; + + using Epilogue = Epilogue_; + + using EVTCompute = typename Epilogue::EVTCompute; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; + + // Compile-time swap_ab flag + static constexpr bool swap_ab = swap_ab_; + + // ----------------------------------------------------------- + // Layout definitions + // ----------------------------------------------------------- + using LayoutA = cutlass::layout::RowMajor; + using LayoutA_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB_T = typename cutlass::layout::LayoutTranspose::type; + + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + + // ----------------------------------------------------------- + // Collective epilogue (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, + conditional_t, AlignmentCD, + ElementD, conditional_t, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // ----------------------------------------------------------- + // Collective mainloop (conditionally swap operands and layouts) + // ----------------------------------------------------------- + using CollectiveMainloop = conditional_t< + swap_ab, + typename 
cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutB_T, AlignmentAB, // Swapped B (as A) + ElementAB, LayoutA_T, AlignmentAB, // Swapped A (as B) + ElementAcc, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp, + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc, + TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>; + + // ----------------------------------------------------------- + // Kernel definition + // ----------------------------------------------------------- + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template struct sm90_fp8_config_default { // M in (128, inf) static_assert(std::is_same()); @@ -22,13 +106,17 @@ struct sm90_fp8_config_default { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; }; -template typename Epilogue> +template struct sm90_fp8_config_M128 { // M in (64, 128] static_assert(std::is_same()); @@ -37,33 +125,146 @@ struct sm90_fp8_config_M128 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; }; -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] +template +struct sm90_fp8_config_M64_N1280 { + // M in (16, 64], N in [1, 1280]
static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; + using TileShape = Shape<_64, _16, _256>; + using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + // enable swap AB for M < 64 + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; }; -template typename Epilogue, +template +struct sm90_fp8_config_M64_N8192 { + // M in (16, 64], N > 1280 + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + + // enable swap AB for M < 64 + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; +}; + +template +struct sm90_fp8_config_M16_N1280 { + // M in [1, 16], N in [1, 1280] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _16, _256>; + using ClusterShape = Shape<_1, _2, _1>; + + // enable swap AB for M < 64 + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; +}; + +template +struct sm90_fp8_config_M16_N8192 { + // M in [1, 16], N > 1280 + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _16, _256>; + using ClusterShape = Shape<_1, _1, _1>; + + 
// enable swap AB for M < 64 + using Cutlass3xGemm = conditional_t< + EnableBias, + cutlass_3x_gemm_sm90_fp8, + cutlass_3x_gemm_sm90_fp8>; +}; + +template +void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = + swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideC c_stride = cutlass::make_cute_packed_stride( + StrideC{}, + swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args = + swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr, + a_stride} + : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr, + b_stride}; + + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, EpilogueArgs&&... 
args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); @@ -71,50 +272,75 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; + EnableBias>::Cutlass3xGemm; using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; + typename sm90_fp8_config_M128::Cutlass3xGemm; + + using Cutlass3xGemmM64_N1280 = + typename sm90_fp8_config_M64_N1280::Cutlass3xGemm; + using Cutlass3xGemmM64_N8192 = + typename sm90_fp8_config_M64_N8192::Cutlass3xGemm; + using Cutlass3xGemmM16_N1280 = + typename sm90_fp8_config_M16_N1280::Cutlass3xGemm; + using Cutlass3xGemmM16_N8192 = + typename sm90_fp8_config_M16_N8192::Cutlass3xGemm; uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 + uint32_t const n = b.size(1); - if (mp2 <= 64) { - // m in [1, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { + if (m <= 16) { + // m in [1, 16] + if (n <= 1280) { + return cutlass_gemm_caller_sm90_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + } + return cutlass_gemm_caller_sm90_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + } else if (m <= 64) { + // m in (16, 64] + if (n <= 1280) { + return cutlass_gemm_caller_sm90_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + } + return cutlass_gemm_caller_sm90_fp8( + out, a, b, b_scales, a_scales, std::forward(args)...); + } else if (m <= 128) { // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm90_fp8( + out, a, b, a_scales, b_scales, std::forward(args)...); } else { // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + return cutlass_gemm_caller_sm90_fp8( + out, a, b, a_scales, b_scales, 
std::forward(args)...); } } -template