Add cutlass support for blackwell fp8 blockwise gemm (#14383)

Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
Shu Wang 2025-05-08 17:09:55 -05:00 committed by GitHub
parent 4f605a6de5
commit 376786fac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 332 additions and 64 deletions

View File

@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
) )
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"

View File

@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif #endif
} }
}; };
template <typename Kernel>
struct enable_sm100_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};

View File

@ -0,0 +1,27 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(
a.size(0) % 4 == 0,
"Input tensor must have a number of rows that is a multiple of 4. ",
"but got: ", a.size(0), " rows.");
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@ -0,0 +1,205 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
class ClusterShape, typename EpilogueScheduler,
typename MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = void;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using LayoutC = LayoutD;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM =
size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN =
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK =
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = ClusterShape;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
// clang-format off
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
// clang-format on
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
auto prob_shape = cute::make_shape(m, n, k, 1);
StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto m = a.size(0);
auto k = a.size(1);
auto n = b.size(1);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
return std::ceil(static_cast<float>(m) / tile1SM) *
std::ceil(static_cast<float>(n) / tile1SM) >=
sms;
};
bool use_2sm = should_use_2sm(m, n);
if (use_2sm) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@ -0,0 +1,57 @@
#include <torch/all.h>
#include "cuda_utils.h"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}

View File

@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm } // namespace vllm

View File

@ -1,8 +1,6 @@
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell). NVIDIA GPUs with sm100 (Blackwell).
@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
} }
#endif #endif

View File

@ -1,8 +1,6 @@
#include <cudaTypedefs.h> #include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper). NVIDIA GPUs with sm90a (Hopper).
@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
int M = a.size(0), N = b.size(1), K = a.size(1); vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
}
} }
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,

View File

@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION #if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) { if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000; return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} }
#endif #endif

View File

@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm, opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias)) (out, a, b, scale_a, scale_b, bias))
@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
return return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias) use_bias)

View File

@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
or br not in (1, weight.shape[0])): or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass: if cutlass_block_fp8_supported and shape_supported_by_cutlass:
rows, cols = input_2d.shape
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
# optimal tensor core usage. Can be removed when targeting platforms
# without this constraint.
should_pad = current_platform.has_device_capability(
100) and rows % 4 != 0
if should_pad:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, 4 - (rows % 4)),
value=0).contiguous()
q_input, x_scale = per_token_group_quant_fp8(input_2d, q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1], block_size[1],
column_major_scales=True) column_major_scales=True)
@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
out_dtype=input.dtype, out_dtype=input.dtype,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale.T) scale_b=weight_scale.T)
if should_pad:
output = output[:rows, :]
else: else:
q_input, x_scale = per_token_group_quant_fp8(input_2d, q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1], block_size[1],