diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4b3bfe0af7f57..cad9f44286536 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index dbe0e30f5cbfe..0877da52435eb 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
 #endif
   }
 };
+
+template <typename Kernel>
+struct enable_sm100_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
new file mode 100644
index 0000000000000..84492553c02f2
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
@@ -0,0 +1,27 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  TORCH_CHECK(
+      a.size(0) % 4 == 0,
+      "Input tensor must have a number of rows that is a multiple of 4, ",
+      "but got: ", a.size(0), " rows.");
+  if (out.dtype() == torch::kBFloat16) {
+    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
+        out, a, b, a_scales, b_scales);
+
+  } else {
+    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
new file mode 100644
index 0000000000000..ef324364c6d5e
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
@@ -0,0 +1,205 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
+          class ClusterShape, typename EpilogueScheduler,
+          typename MainloopScheduler>
+struct cutlass_3x_gemm_fp8_blockwise {
+  using ElementAB = cutlass::float_e4m3_t;
+
+  using ElementA = ElementAB;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = ElementAB;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementC = void;
+  using ElementD = OutType;
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using LayoutC = LayoutD;
+  static constexpr int AlignmentC = AlignmentD;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  // MMA and Cluster Tile Shapes
+  // Shape of the tile computed by the tcgen05 MMA (MmaTileShape); it can span
+  // 2 SMs when the cluster shape is a multiple of 2.
+  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
+  static constexpr int ScaleGranularityM =
+      size<0>(MmaTileShape{}) / ScaleMsPerTile;
+  static constexpr int ScaleGranularityN =
+      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
+  static constexpr int ScaleGranularityK =
+      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+
+  // Shape of the threadblocks in a cluster
+  using ClusterShape_MNK = ClusterShape;
+
+  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using ElementScalar = float;
+  // clang-format off
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueScheduler,
+      DefaultOperation
+  >::CollectiveOp;
+
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopScheduler
+  >::CollectiveOp;
+  // clang-format on
+
+  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
+
+  struct GemmKernel : public KernelType {};
+};
+
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+  using ScaleConfig = typename Gemm::ScaleConfig;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+  auto prob_shape = cute::make_shape(m, n, k, 1);
+
+  StrideA a_stride;
+  StrideB b_stride;
+  StrideC c_stride;
+  a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride,
+      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename OutType>
+void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
+  auto m = a.size(0);
+  auto k = a.size(1);
+  auto n = b.size(1);
+  int sms;
+  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
+
+  // Use the 2-SM kernel only when the problem produces at least as many
+  // 128x128 output tiles as there are SMs on the device.
+  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
+    return std::ceil(static_cast<float>(m) / tile1SM) *
+               std::ceil(static_cast<float>(n) / tile1SM) >=
+           sms;
+  };
+  bool use_2sm = should_use_2sm(m, n);
+  if (use_2sm) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
+        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
+        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
new file mode 100644
index 0000000000000..b589a479081ea
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
@@ -0,0 +1,57 @@
+#include <torch/all.h>
+#include "cuda_utils.h"
+
+template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
+void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        std::optional<torch::Tensor> const& bias,
+                        Fp8Func fp8_func, Int8Func int8_func,
+                        BlockwiseFunc blockwise_func) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  int M = a.size(0), N = b.size(1), K = a.size(1);
+
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
+    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == torch::kFloat8_e4m3fn) {
+      fp8_func(c, a, b, a_scales, b_scales, bias);
+    } else {
+      TORCH_CHECK(a.dtype() == torch::kInt8);
+      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
+        int8_func(c, a, b, a_scales, b_scales, bias);
+      } else {
+        TORCH_CHECK(false, "Int8 not supported for this architecture");
+      }
+    }
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+              cuda_utils::ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
+    // 1x128 per-token group scales for activations
+    // 128x128 blockwise scales for weights
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
+                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                "]\n"
+                "b_scale_group_shape must be [128, 128]. Got: [",
+                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported for blockwise scaled_mm");
+    blockwise_func(c, a, b, a_scales, b_scales);
+  }
+}
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
index 85272804774db..c1242fdb39da9 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                  torch::Tensor const& b_scales,
                                  std::optional<torch::Tensor> const& bias);
 
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
index 459eb1bb76eb0..0cbd5305e3c25 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
@@ -1,8 +1,6 @@
-#include <torch/all.h>
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm100 (Blackwell).
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm100_fp8,
+                     nullptr,  // int8 not supported on SM100
+                     vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
 }
 
 #endif
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
index bcb91040d5e2e..211302171f074 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
@@ -1,8 +1,6 @@
-#include <torch/all.h>
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm90a (Hopper).
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-
-  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
-    // Standard per-tensor/per-token/per-channel scaling
-    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-    if (a.dtype() == torch::kFloat8_e4m3fn) {
-      vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
-    } else {
-      TORCH_CHECK(a.dtype() == torch::kInt8);
-      vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
-    }
-  } else {
-    using GroupShape = std::array<int64_t, 2>;
-    auto make_group_shape = [](torch::Tensor const& x,
-                               torch::Tensor const& s) -> GroupShape {
-      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
-              cuda_utils::ceil_div(x.size(1), s.size(1))};
-    };
-
-    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
-    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
-
-    // 1x128 per-token group scales for activations
-    // 128x128 blockwise scales for weights
-    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
-                 b_scale_group_shape == GroupShape{128, 128} &&
-                 a.dtype() == torch::kFloat8_e4m3fn &&
-                 b.dtype() == torch::kFloat8_e4m3fn),
-                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
-                "a_scale_group_shape must be [1, 128]. Got: [",
-                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
-                "]\n"
-                "b_scale_group_shape must be [128, 128]. Got: [",
-                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  }
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm90_fp8,
+                     vllm::cutlass_scaled_mm_sm90_int8,
+                     vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
 }
 
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 54b63894e4cbc..ddcc48cccab12 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
 #if defined CUDA_VERSION
   if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
     return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 100) {
+    return CUDA_VERSION >= 12080;
   }
 #endif
 
diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py
index 8084d9bf2c2da..633addd421f43 100644
--- a/tests/kernels/quantization/test_cutlass_scaled_mm.py
+++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
+    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
 
     opcheck(torch.ops._C.cutlass_scaled_mm,
             (out, a, b, scale_a, scale_b, bias))
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
         return
     if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
         return
+    if m % 4 != 0 and current_platform.has_device_capability(100):
+        return
 
     cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape,
                             b_scale_group_shape, use_bias)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 064cbb8cf52d7..3bb42e737f10a 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
             or br not in (1, weight.shape[0])):
         shape_supported_by_cutlass = False
     if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+        rows, cols = input_2d.shape
+        # Blackwell GPUs (SM100) require the number of rows to be a multiple
+        # of 4 for optimal tensor core usage. This padding can be removed when
+        # targeting platforms without this constraint.
+        should_pad = current_platform.has_device_capability(
+            100) and rows % 4 != 0
+        if should_pad:
+            input_2d = torch.nn.functional.pad(input_2d,
+                                               (0, 0, 0, 4 - (rows % 4)),
+                                               value=0).contiguous()
         q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                      block_size[1],
                                                      column_major_scales=True)
@@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
                                        out_dtype=input.dtype,
                                        scale_a=x_scale,
                                        scale_b=weight_scale.T)
+        if should_pad:
+            output = output[:rows, :]
     else:
         q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                      block_size[1],