diff --git a/CMakeLists.txt b/CMakeLists.txt
index e2cc0ccdef51..093330caa4f9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -427,6 +427,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       set(SRCS
          "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
          "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+         "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
       )
       set_gencode_flags_for_srcs(
         SRCS "${SRCS}"
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index 195872e8edd3..f2c1dcf69f69 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
 #endif
   }
 };
+
+template <typename Kernel>
+struct enable_sm120_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
new file mode 100644
index 000000000000..5515374a5759
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
@@ -0,0 +1,23 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  if (out.dtype() == torch::kBFloat16) {
+    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
+        out, a, b, a_scales, b_scales);
+
+  } else {
+    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
new file mode 100644
index 000000000000..d50a83ae1cd4
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
@@ -0,0 +1,183 @@
+#pragma once
+
+#include "cuda_utils.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+// clang-format off
+template <typename OutType, typename MmaTileShape, typename ClusterShape,
+          typename EpilogueScheduler, typename MainloopScheduler,
+          int ScaleGranularityM = 1, int ScaleGranularityN = 128,
+          int ScaleGranularityK = 128>
+struct cutlass_3x_gemm_fp8_blockwise {
+  using ElementAB = cutlass::float_e4m3_t;
+
+  using ElementA = ElementAB;
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = ElementAB;
+  // ColumnMajor is used for B to match the CUTLASS convention.
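+  // (With A row-major and B column-major, both operands are K-major, which is
+  //  the operand layout the FP8 tensor-core MMA expects.)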
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementD = OutType;
+  using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using ElementC = void;  // TODO: support bias
+  using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
+  static constexpr int AlignmentC = AlignmentD;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using ElementScalar = float;
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueScheduler,
+      DefaultOperation
+  >::CollectiveOp;
+
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          cute::tuple<LayoutA, LayoutSFA>,
+          AlignmentA,
+          ElementB,
+          cute::tuple<LayoutB, LayoutSFB>,
+          AlignmentB,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainloopScheduler
+      >::CollectiveOp;
+
+  using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
+
+  struct GemmKernel : public KernelType {};
+};
+
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+  using ScaleConfig = typename Gemm::ScaleConfig;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+
+  StrideA a_stride;
+  StrideB b_stride;
+  StrideC c_stride;
+  a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
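+  // The fp32 blockwise scales are handed to the mainloop as raw pointers
+  // together with the scale-factor layouts deduced above.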
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  auto mainloop_args = [&](){
+    return typename GemmKernel::MainloopArguments{
+      a_ptr, a_stride, b_ptr, b_stride,
+      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+    };
+  }();
+  auto prob_shape = cute::make_shape(m, n, k, 1);
+
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename OutType>
+void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
+  // TODO: better heuristics
+  cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<OutType, Shape<_128, _128, _128>,
+      Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
+      cutlass::gemm::collective::KernelScheduleAuto>>(
+      out, a, b, a_scales, b_scales);
+}
+
+} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
index e049a5f2d2c9..9ceb3a3ece5d 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales);
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
index 0c47ab82991d..dc87c5c35cb8 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
@@ -1,11 +1,9 @@
-#include <cudaTypedefs.h>
-
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm120 (Blackwell Geforce).
+   NVIDIA GPUs with sm120 (Blackwell).
 */
 
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm120_fp8,
+                     nullptr,  // int8 not supported on SM120
+                     vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
 }
 #endif