[Kernel] Add support for block FP8 on SM120 (NVIDIA 5090 and RTX PRO 6000) (#22131)
Signed-off-by: Junhao Li <junhao@ubicloud.com>
Parent: b2c8ce57c6
Commit: 3303f134e0
@@ -427,6 +427,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
 #endif
   }
 };
+
+template <typename Kernel>
+struct enable_sm120_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
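
This guard follows the same idiom as the enable_sm100_only wrapper visible in the hunk context: in device compilation passes for any architecture other than compute capability 12.0, the kernel body is compiled out, leaving an empty operator() so that multi-architecture fat binaries still build and link. A minimal standalone sketch of the idiom (the wrapper name below is hypothetical; only enable_sm120_only itself is from the diff):

#include <utility>

// __CUDA_ARCH__ is defined only during device compilation; the value 1200
// corresponds to compute capability 12.0, i.e. SM120 (RTX 5090 / RTX PRO 6000).
template <typename Kernel>
struct enable_sm120_only_sketch : Kernel {
  template <typename... Args>
  __device__ void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
    Kernel::operator()(std::forward<Args>(args)...);  // real kernel body
#endif
    // On every other architecture this collapses to an empty function, so a
    // valid (if inert) symbol still exists in each device pass.
  }
};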
@@ -0,0 +1,23 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  if (out.dtype() == torch::kBFloat16) {
+    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
+        out, a, b, a_scales, b_scales);
+
+  } else {
+    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
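
A hedged usage sketch of this new entry point (hypothetical caller, not part of the diff; it assumes the declaration from c3x/scaled_mm_kernels.hpp is in scope). The dtype requirements are read directly from the body above: out must be bf16 or fp16, a and b hold fp8 e4m3 values, and both scale tensors are fp32:

#include <torch/all.h>

// Hypothetical example: run the blockwise-scaled fp8 GEMM with a bf16 output.
void run_blockwise_fp8_example(torch::Tensor const& a,         // [M, K], torch::kFloat8_e4m3fn
                               torch::Tensor const& b,         // [K, N], torch::kFloat8_e4m3fn
                               torch::Tensor const& a_scales,  // torch::kFloat32
                               torch::Tensor const& b_scales)  // torch::kFloat32
{
  auto out = torch::empty({a.size(0), b.size(1)},
                          torch::dtype(torch::kBFloat16).device(a.device()));
  vllm::cutlass_scaled_mm_blockwise_sm120_fp8(out, a, b, a_scales, b_scales);
}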
@@ -0,0 +1,183 @@
+#pragma once
+
+#include "cuda_utils.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+// clang-format off
+template <class OutType, int ScaleGranularityM,
+          int ScaleGranularityN, int ScaleGranularityK,
+          class MmaTileShape, class ClusterShape,
+          class EpilogueScheduler, class MainloopScheduler>
+struct cutlass_3x_gemm_fp8_blockwise {
+  using ElementAB = cutlass::float_e4m3_t;
+
+  using ElementA = ElementAB;
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = ElementAB;
+  // ColumnMajor is used for B to match the CUTLASS convention.
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementD = OutType;
+  using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using ElementC = void;  // TODO: support bias
+  using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
+  static constexpr int AlignmentC = AlignmentD;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  using ArchTag = cutlass::arch::Sm120;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using ElementScalar = float;
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueScheduler,
+      DefaultOperation
+  >::CollectiveOp;
+
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          cute::tuple<LayoutA, LayoutSFA>,
+          AlignmentA,
+          ElementB,
+          cute::tuple<LayoutB, LayoutSFB>,
+          AlignmentB,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainloopScheduler
+      >::CollectiveOp;
+
+  using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
+
+  struct GemmKernel : public KernelType {};
+};
+
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+  using ScaleConfig = typename Gemm::ScaleConfig;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+
+  StrideA a_stride;
+  StrideB b_stride;
+  StrideC c_stride;
+  a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  auto mainloop_args = [&](){
+    return typename GemmKernel::MainloopArguments{
+        a_ptr, a_stride, b_ptr, b_stride,
+        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+    };
+  }();
+  auto prob_shape = cute::make_shape(m, n, k, 1);
+
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename OutType>
+void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
+  // TODO: better heuristics
+  cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+      OutType, 1, 128, 128, Shape<_128, _128, _128>,
+      Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
+      cutlass::gemm::collective::KernelScheduleAuto>>(
+      out, a, b, a_scales, b_scales);
+}
+
+}  // namespace vllm
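
One note on the instantiation at the bottom of this header: ScaleGranularity{M,N,K} = {1, 128, 128} means each scale for A covers a 1x128 tile (per row, per 128-wide K block) and each scale for B covers a 128x128 tile. A small arithmetic sketch of the scale counts this implies (the helpers below are hypothetical; the shapes follow from the granularity constants, not from any statement in the diff):

#include <cstdint>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// For an (M, K) x (K, N) problem with granularity {1, 128, 128}:
int64_t num_a_scales(int64_t M, int64_t K) {
  return M * ceil_div(K, 128);                 // one fp32 scale per row per 128-wide K block
}
int64_t num_b_scales(int64_t N, int64_t K) {
  return ceil_div(N, 128) * ceil_div(K, 128);  // one fp32 scale per 128x128 weight tile
}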
@@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales);
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm
@@ -1,11 +1,9 @@
-#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
 This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-NVIDIA GPUs with sm120 (Blackwell Geforce).
+NVIDIA GPUs with sm120 (Blackwell).
 */
 
 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm120_fp8,
+                     nullptr,  // int8 not supported on SM120
+                     vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
 }
 
 #endif
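
The rewritten cutlass_scaled_mm_sm120 delegates the scale-shape checks that were previously inlined here to dispatch_scaled_mm from c3x/scaled_mm_helper.hpp, passing nullptr for the int8 kernel slot since SM120 has no int8 path. A hedged sketch of the routing shape (the real helper's signature and conditions may differ; this only illustrates the per-token-versus-blockwise split):

#include <optional>
#include <torch/all.h>

// Illustrative only: 2-D, non-scalar scale tensors indicate blockwise fp8;
// anything else falls through to the per-tensor/per-token fp8 kernel.
void dispatch_scaled_mm_sketch(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales,
                               std::optional<torch::Tensor> const& bias) {
  bool blockwise = a_scales.dim() == 2 && b_scales.dim() == 2 &&
                   a_scales.numel() > 1 && b_scales.numel() > 1;
  if (blockwise) {
    vllm::cutlass_scaled_mm_blockwise_sm120_fp8(c, a, b, a_scales, b_scales);
  } else {
    vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
  }
}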