diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
index 84492553c02f2..4a8a5ed02d6ce 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(
-      a.size(0) % 4 == 0,
-      "Input tensor must have a number of rows that is a multiple of 4. ",
-      "but got: ", a.size(0), " rows.");
   if (out.dtype() == torch::kBFloat16) {
     cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
         out, a, b, a_scales, b_scales);
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
index ef324364c6d5e..c841125dbb734 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
 
@@ -22,49 +23,49 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
-          class ClusterShape, typename EpilogueScheduler,
-          typename MainloopScheduler>
+// clang-format off
+template <typename OutType, int ScaleGranularityM, int ScaleGranularityN,
+          int ScaleGranularityK, typename MmaTileShape, typename ClusterShape,
+          typename EpilogueScheduler, typename MainloopScheduler,
+          bool swap_ab_ = false>
 struct cutlass_3x_gemm_fp8_blockwise {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = cutlass::float_e4m3_t;
 
   using ElementA = ElementAB;
   using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
 
   using ElementB = ElementAB;
   using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
 
-  using ElementC = void;
   using ElementD = OutType;
   using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
 
+  using ElementC = void;  // TODO: support bias
   using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
   static constexpr int AlignmentC = AlignmentD;
 
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
 
-  // MMA and Cluster Tile Shapes
-  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
-  // Shape %2 == 0
   using MmaTileShape_MNK = Shape<_128,_128,_128>;
-  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-  static constexpr int ScaleGranularityM =
-      size<0>(MmaTileShape{}) / ScaleMsPerTile;
-  static constexpr int ScaleGranularityN =
-      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-  static constexpr int ScaleGranularityK =
-      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+  using ScaleConfig = conditional_t<swap_ab,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
 
-  // Shape of the threadblocks in a cluster
-  using ClusterShape_MNK = ClusterShape;
-
-  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
 
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
   using ElementScalar = float;
-  // clang-format off
   using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       ElementCompute,
       ElementC,
-      LayoutC,
+      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
       AlignmentC,
       ElementD,
-      LayoutD,
+      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
       AlignmentD,
       EpilogueScheduler,
       DefaultOperation
   >::CollectiveOp;
 
   using StageCountType = cutlass::gemm::collective::StageCountAuto;
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag,
-      OperatorClass,
-      ElementA,
-      cute::tuple<LayoutA, LayoutSFA>,
-      AlignmentA,
-      ElementB,
-      cute::tuple<LayoutB, LayoutSFB>,
-      AlignmentB,
-      ElementAccumulator,
-      MmaTileShape,
-      ClusterShape,
-
+  using CollectiveMainloop = conditional_t<swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementB,
+      cute::tuple<LayoutB_Transpose, LayoutSFA>,
+      AlignmentB,
+      ElementA,
+      cute::tuple<LayoutA_Transpose, LayoutSFB>,
+      AlignmentA,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      MainloopScheduler
-  >::CollectiveOp;
-  // clang-format on
+      MainloopScheduler
+      >::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopScheduler
+      >::CollectiveOp>;
 
   using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
 
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   using ElementD = typename Gemm::ElementD;
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  auto prob_shape = cute::make_shape(m, n, k, 1);
 
   StrideA a_stride;
   StrideB b_stride;
@@ -146,11 +160,13 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   b_stride =
       cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
   c_stride =
-      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
 
-  LayoutSFA layout_SFA =
+  LayoutSFA layout_SFA = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-  LayoutSFB layout_SFB =
+  LayoutSFB layout_SFB = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
 
   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
@@ -158,9 +174,22 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   auto a_scales_ptr = static_cast<ElementBlockScale*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<ElementBlockScale*>(b_scales.data_ptr());
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride,
-      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+  auto mainloop_args = [&](){
+    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+    if (swap_ab) {
+      return typename GemmKernel::MainloopArguments{
+          b_ptr, b_stride, a_ptr, a_stride,
+          b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
+      };
+    }
+    else {
+      return typename GemmKernel::MainloopArguments{
+          a_ptr, a_stride, b_ptr, b_stride,
+          a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+      };
+    }
+  }();
+  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
 
   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                                torch::Tensor const& b,
                                                torch::Tensor const& a_scales,
                                                torch::Tensor const& b_scales) {
-  auto m = a.size(0);
-  auto k = a.size(1);
-  auto n = b.size(1);
-  int sms;
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
 
-  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-    return std::ceil(static_cast<float>(m) / tile1SM) *
-               std::ceil(static_cast<float>(n) / tile1SM) >=
-           sms;
-  };
-  bool use_2sm = should_use_2sm(m, n);
-  if (use_2sm) {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
-        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-        out, a, b, a_scales, b_scales);
+  constexpr int TILE_K = 128;
+  // TODO: better heuristics
+  bool swap_ab = (m < 16) || (m % 4 != 0);
+  bool use_tma_epilogue = (m * n) % 4 == 0;
+  if (!swap_ab) {
+    constexpr int TILE_N = 128;
+    int tile_m = 256;
+    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+      tile_m = 64;
+    }
+    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+      tile_m = 128;
+    }
+    if (tile_m == 64) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else if (tile_m == 128) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else { // tile_m == 256
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    }
   } else {
+    // TODO: Test more tile N configs
+    constexpr int TILE_M = 128;
+    constexpr int TILE_N = 16;
+    // TMA epilogue isn't compatible with Swap A/B
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
-        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
+        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
         out, a, b, a_scales, b_scales);
   }
 }
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 1ebd2a8985824..270979c8e932e 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -136,24 +136,10 @@ def apply_w8a8_block_fp8_linear(
         use_cutlass, use_aiter_and_is_supported)
 
     if use_cutlass:
-        rows, cols = input_2d.shape
-        # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
-        # optimal tensor core usage. Can be removed when targeting platforms
-        # without this constraint.
-        should_pad = current_platform.has_device_capability(
-            100) and rows % 4 != 0
-        if should_pad:
-            input_2d = torch.nn.functional.pad(input_2d,
-                                               (0, 0, 0, 4 - (rows % 4)),
-                                               value=0).contiguous()
-
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=use_cutlass)
-
         output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
                                       block_size, input.dtype)
-        if should_pad:
-            output = output[:rows, :]
 
     else:
         q_input, x_scale = per_token_group_quant_fp8(
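
Note: the sketch below is a minimal, standalone illustration of the tile-selection heuristic introduced in cutlass_gemm_blockwise_sm100_fp8_dispatch above; it is not part of the change. It assumes a free-standing ceil_div helper in place of cuda_utils::ceil_div, and BlockwiseConfig / pick_blockwise_config are hypothetical names used only to show how swap_ab, tile_m, and the epilogue choice fall out of m, n, and the SM count.

#include <cstdint>

namespace {

// Hypothetical stand-in for cuda_utils::ceil_div used by the dispatch code.
inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct BlockwiseConfig {
  bool swap_ab;           // run the problem with A and B swapped (transposed output)
  bool use_tma_epilogue;  // the swap_ab path always falls back to the no-smem epilogue
  int tile_m;             // 64 or 128 on one SM, 256 on a 2-SM cluster
};

// Mirrors the heuristic in the diff: swap A/B when m is tiny or not a multiple
// of 4; otherwise pick the smallest tile_m whose grid (with TILE_N = 128) still
// occupies all SMs, preferring 64, then 128, then the 2-SM 256 tile.
inline BlockwiseConfig pick_blockwise_config(int64_t m, int64_t n, int sms) {
  constexpr int TILE_N = 128;
  BlockwiseConfig cfg{};
  cfg.swap_ab = (m < 16) || (m % 4 != 0);
  // TMA epilogue is only used when it is compatible (not swap_ab) and m*n is a
  // multiple of 4, matching use_tma_epilogue in the dispatch function.
  cfg.use_tma_epilogue = !cfg.swap_ab && (m * n) % 4 == 0;
  if (cfg.swap_ab) {
    cfg.tile_m = 128;  // swap path uses a fixed 128x16x128 tile
    return cfg;
  }
  cfg.tile_m = 256;
  if (ceil_div(n, TILE_N) * ceil_div(m, 64) <= sms) {
    cfg.tile_m = 64;
  } else if (ceil_div(n, TILE_N) * ceil_div(m, 128) <= sms) {
    cfg.tile_m = 128;
  }
  return cfg;
}

}  // namespace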