mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 09:42:14 +08:00
[Kernel] Basic tuned configs for NVFP4 CUTLASS dense GEMM (#20646)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
53fa457391
commit
d47661f0cd
@ -30,35 +30,40 @@
|
|||||||
|
|
||||||
#include "cutlass/util/packed_stride.hpp"
|
#include "cutlass/util/packed_stride.hpp"
|
||||||
|
|
||||||
|
#include "core/math.hpp"
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
// Kernel Perf config
|
|
||||||
template <typename T>
|
|
||||||
struct KernelTraits;
|
|
||||||
|
|
||||||
template <>
|
// Configuration for M in (256, inf)
|
||||||
struct KernelTraits<float> {
|
struct sm100_fp4_config_default {
|
||||||
using MmaTileShape = Shape<_128, _128, _256>;
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_256, _256, _256>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Configuration for M in (16, 256]
|
||||||
|
struct sm100_fp4_config_M256 {
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_256, _128, _256>;
|
||||||
|
using ClusterShape = Shape<_2, _1, _1>;
|
||||||
|
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Configuration for M in [1, 16]
|
||||||
|
struct sm100_fp4_config_M16 {
|
||||||
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||||
|
using TileShape = Shape<_128, _128, _256>;
|
||||||
using ClusterShape = Shape<_1, _1, _1>;
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename Config, typename OutType>
|
||||||
struct KernelTraits<cutlass::half_t> {
|
|
||||||
using MmaTileShape = Shape<_256, _256, _256>;
|
|
||||||
using ClusterShape = Shape<_4, _4, _1>;
|
|
||||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct KernelTraits<cutlass::bfloat16_t> {
|
|
||||||
using MmaTileShape = Shape<_256, _256, _256>;
|
|
||||||
using ClusterShape = Shape<_4, _4, _1>;
|
|
||||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Fp4GemmSm100 {
|
struct Fp4GemmSm100 {
|
||||||
// A matrix configuration
|
// A matrix configuration
|
||||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||||
@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
|
|||||||
static constexpr int AlignmentB = 32;
|
static constexpr int AlignmentB = 32;
|
||||||
|
|
||||||
// C/D matrix configuration
|
// C/D matrix configuration
|
||||||
using ElementD = T;
|
using ElementD = OutType;
|
||||||
using ElementC = T;
|
using ElementC = OutType;
|
||||||
using LayoutCTag = cutlass::layout::RowMajor;
|
using LayoutCTag = cutlass::layout::RowMajor;
|
||||||
using LayoutDTag = cutlass::layout::RowMajor;
|
using LayoutDTag = cutlass::layout::RowMajor;
|
||||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||||
|
|
||||||
// Kernel functional config
|
// Kernel functional config
|
||||||
using ElementAccumulator = float;
|
using ElementAccumulator = float;
|
||||||
using ArchTag = cutlass::arch::Sm100;
|
using ArchTag = cutlass::arch::Sm100;
|
||||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||||
|
|
||||||
// Kernel Perf config
|
// Use config's tile shapes
|
||||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
using MmaTileShape = typename Config::TileShape;
|
||||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
using ClusterShape = typename Config::ClusterShape;
|
||||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
|
||||||
|
|
||||||
using CollectiveEpilogue =
|
using CollectiveEpilogue =
|
||||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
|
|||||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename Config>
|
||||||
typename T::Gemm::Arguments args_from_options(
|
typename Config::Gemm::Arguments args_from_options(
|
||||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||||
int64_t M, int64_t N, int64_t K) {
|
int64_t M, int64_t N, int64_t K) {
|
||||||
using ElementA = typename T::Gemm::ElementA;
|
using ElementA = typename Config::Gemm::ElementA;
|
||||||
using ElementB = typename T::Gemm::ElementB;
|
using ElementB = typename Config::Gemm::ElementB;
|
||||||
using ElementSFA = cutlass::float_ue4m3_t;
|
using ElementSFA = cutlass::float_ue4m3_t;
|
||||||
using ElementSFB = cutlass::float_ue4m3_t;
|
using ElementSFB = cutlass::float_ue4m3_t;
|
||||||
using ElementD = typename T::Gemm::ElementD;
|
using ElementD = typename Config::Gemm::ElementD;
|
||||||
using ElementCompute = float;
|
using ElementCompute = float;
|
||||||
using StrideA = typename T::StrideA;
|
using StrideA = typename Config::StrideA;
|
||||||
using StrideB = typename T::StrideB;
|
using StrideB = typename Config::StrideB;
|
||||||
using StrideD = typename T::StrideD;
|
using StrideD = typename Config::StrideD;
|
||||||
using Sm100BlkScaledConfig =
|
using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
|
||||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||||
|
|
||||||
int m = static_cast<int>(M);
|
int m = static_cast<int>(M);
|
||||||
int n = static_cast<int>(N);
|
int n = static_cast<int>(N);
|
||||||
@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||||
cute::make_shape(m, n, k, 1));
|
cute::make_shape(m, n, k, 1));
|
||||||
|
|
||||||
typename T::Gemm::Arguments arguments{
|
typename Config::Gemm::Arguments arguments{
|
||||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||||
{m, n, k, 1},
|
{m, n, k, 1},
|
||||||
{// Mainloop arguments
|
{// Mainloop arguments
|
||||||
@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
return arguments;
|
return arguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename Config>
|
||||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
typename Config::Gemm gemm;
|
||||||
|
|
||||||
auto arguments =
|
auto arguments =
|
||||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||||
|
|
||||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||||
auto const workspace_options =
|
auto const workspace_options =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
|||||||
|
|
||||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dispatch function to select appropriate config based on M
|
||||||
|
template <typename OutType>
|
||||||
|
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||||
|
torch::Tensor const& B,
|
||||||
|
torch::Tensor const& A_sf,
|
||||||
|
torch::Tensor const& B_sf,
|
||||||
|
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||||
|
int64_t k, cudaStream_t stream) {
|
||||||
|
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||||
|
|
||||||
|
if (mp2 <= 16) {
|
||||||
|
// m in [1, 16]
|
||||||
|
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
|
||||||
|
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
} else if (mp2 <= 256) {
|
||||||
|
// m in (16, 256]
|
||||||
|
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
|
||||||
|
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
} else {
|
||||||
|
// m in (256, inf)
|
||||||
|
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
|
||||||
|
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
template <typename T>
|
template <typename OutType>
|
||||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
torch::Tensor const& B,
|
||||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
torch::Tensor const& A_sf,
|
||||||
cudaStream_t stream) {
|
torch::Tensor const& B_sf,
|
||||||
|
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||||
|
int64_t k, cudaStream_t stream) {
|
||||||
TORCH_CHECK(false,
|
TORCH_CHECK(false,
|
||||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||||
"a CUTLASS 3.8 source directory to enable support.");
|
"a CUTLASS 3.8 source directory to enable support.");
|
||||||
@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||||
|
|
||||||
if (out_dtype == at::ScalarType::Half) {
|
if (out_dtype == at::ScalarType::Half) {
|
||||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||||
|
k, stream);
|
||||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||||
} else if (out_dtype == at::ScalarType::Float) {
|
m, n, k, stream);
|
||||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||||
|
")");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user