From d47661f0cd6ce28504a2c03d2d2105521a591f28 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 12 Jul 2025 01:05:33 +0900
Subject: [PATCH] [Kernel] Basic tuned configs for NVFP4 CUTLASS dense GEMM
 (#20646)

Signed-off-by: mgoin
---
 .../fp4/nvfp4_scaled_mm_kernels.cu            | 139 +++++++++++-------
 1 file changed, 87 insertions(+), 52 deletions(-)

diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
index 7572a7eb3122d..5bc4c38a275ca 100644
--- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -30,35 +30,40 @@
 
 #include "cutlass/util/packed_stride.hpp"
 
+#include "core/math.hpp"
+
 using namespace cute;
 
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 
-// Kernel Perf config
-template <typename T>
-struct KernelTraits;
-template <>
-struct KernelTraits<float> {
-  using MmaTileShape = Shape<_128, _128, _256>;
+// Configuration for M in (256, inf)
+struct sm100_fp4_config_default {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+};
+
+// Configuration for M in (16, 256]
+struct sm100_fp4_config_M256 {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_256, _128, _256>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
+};
+
+// Configuration for M in [1, 16]
+struct sm100_fp4_config_M16 {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _128, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
   using PerSmTileShape_MNK = Shape<_128, _128, _256>;
 };
 
-template <>
-struct KernelTraits<cutlass::half_t> {
-  using MmaTileShape = Shape<_256, _256, _256>;
-  using ClusterShape = Shape<_4, _4, _1>;
-  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
-};
-
-template <>
-struct KernelTraits<cutlass::bfloat16_t> {
-  using MmaTileShape = Shape<_256, _256, _256>;
-  using ClusterShape = Shape<_4, _4, _1>;
-  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
-};
-
-template <typename T>
+template <typename Config, typename OutType>
 struct Fp4GemmSm100 {
   // A matrix configuration
   using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
@@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
   static constexpr int AlignmentB = 32;
 
   // C/D matrix configuration
-  using ElementD = T;
-  using ElementC = T;
+  using ElementD = OutType;
+  using ElementC = OutType;
   using LayoutCTag = cutlass::layout::RowMajor;
   using LayoutDTag = cutlass::layout::RowMajor;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
 
+  // Kernel functional config
   using ElementAccumulator = float;
   using ArchTag = cutlass::arch::Sm100;
   using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
 
-  // Kernel Perf config
-  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
-  using ClusterShape = typename KernelTraits<T>::ClusterShape;
-  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
+  // Use config's tile shapes
+  using MmaTileShape = typename Config::TileShape;
+  using ClusterShape = typename Config::ClusterShape;
+  using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
   using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
 };
 
-template <typename T>
-typename T::Gemm::Arguments args_from_options(
+template <typename Config>
+typename Config::Gemm::Arguments args_from_options(
     at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
     at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
     int64_t M, int64_t N, int64_t K) {
-  using ElementA = typename T::Gemm::ElementA;
-  using ElementB = typename T::Gemm::ElementB;
+  using ElementA = typename Config::Gemm::ElementA;
+  using ElementB = typename Config::Gemm::ElementB;
   using ElementSFA = cutlass::float_ue4m3_t;
   using ElementSFB = cutlass::float_ue4m3_t;
-  using ElementD = typename T::Gemm::ElementD;
+  using ElementD = typename Config::Gemm::ElementD;
   using ElementCompute = float;
-  using StrideA = typename T::StrideA;
-  using StrideB = typename T::StrideB;
-  using StrideD = typename T::StrideD;
-  using Sm100BlkScaledConfig =
-      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+  using StrideA = typename Config::StrideA;
+  using StrideB = typename Config::StrideB;
+  using StrideD = typename Config::StrideD;
+  using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
+      CollectiveMainloop::Sm1xxBlkScaledConfig;
 
   int m = static_cast<int>(M);
   int n = static_cast<int>(N);
@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
   auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
       cute::make_shape(m, n, k, 1));
 
-  typename T::Gemm::Arguments arguments{
+  typename Config::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
       {m, n, k, 1},
       {// Mainloop arguments
@@ -167,17 +173,17 @@
   return arguments;
 }
 
-template <typename T>
+template <typename Config>
 void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
              at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
-  typename Fp4GemmSm100<T>::Gemm gemm;
+  typename Config::Gemm gemm;
 
   auto arguments =
-      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
+      args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
 
-  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
+  size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
 
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }
+
+// Dispatch function to select appropriate config based on M
+template <typename OutType>
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+
+  if (mp2 <= 16) {
+    // m in [1, 16]
+    runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (mp2 <= 256) {
+    // m in (16, 256]
+    runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    // m in (256, inf)
+    runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
 #else
-template <typename T>
-void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
-             at::Tensor const& A_sf, at::Tensor const& B_sf,
-             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
-             cudaStream_t stream) {
+template <typename OutType>
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
   TORCH_CHECK(false,
               "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
               "a CUTLASS 3.8 source directory to enable support.");
@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
 
   if (out_dtype == at::ScalarType::Half) {
-    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
+                                               k, stream);
   } else if (out_dtype == at::ScalarType::BFloat16) {
-    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::Float) {
-    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
+                                                   m, n, k, stream);
   } else {
-    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
+                ")");
  }
 }
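
For reference, the dispatch heuristic introduced above rounds M up to the next power of two, clamps it to at least 16, and picks a tile config bucket from the result. Below is a minimal standalone sketch of that selection logic; the local next_pow_2 is an assumed stand-in for the helper included from core/math.hpp, whose implementation this patch does not show.

    // Standalone sketch of the M-based config selection made by
    // cutlass_fp4_gemm_dispatch. The next_pow_2 below is an assumption
    // standing in for the core/math.hpp helper, for illustration only.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Smallest power of two >= x (assumed to match core/math.hpp's next_pow_2).
    static uint32_t next_pow_2(uint32_t x) {
      if (x <= 1) return 1;
      return 1u << (32 - __builtin_clz(x - 1));
    }

    // Mirrors the bucket choice in cutlass_fp4_gemm_dispatch.
    static const char* select_config(int64_t m) {
      uint32_t const mp2 = std::max(static_cast<uint32_t>(16),
                                    next_pow_2(static_cast<uint32_t>(m)));
      if (mp2 <= 16) return "sm100_fp4_config_M16";    // m in [1, 16]
      if (mp2 <= 256) return "sm100_fp4_config_M256";  // m in (16, 256]
      return "sm100_fp4_config_default";               // m in (256, inf)
    }

    int main() {
      for (int64_t m : {1, 16, 17, 64, 256, 257, 8192}) {
        std::printf("M=%-5lld -> %s\n", static_cast<long long>(m),
                    select_config(m));
      }
      return 0;
    }

With this bucketing, every M from 17 through 256 rounds into the (16, 256] bucket and takes the Shape<_256, _128, _256> tile, while the clamp to 16 maps all small shapes onto sm100_fp4_config_M16's Shape<_128, _128, _256> tile with a single-SM cluster.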