From 789562c28c143201a1d2ca35f7adcdf54ef832e5 Mon Sep 17 00:00:00 2001 From: "Roberto L. Castro" <38211239+LopezCastroRoberto@users.noreply.github.com> Date: Sun, 3 Aug 2025 09:54:22 +0200 Subject: [PATCH] Support CUTLASS NVFP4 (w4a4) for Blackwell Geforce GPUs (SM120) (#21309) Signed-off-by: LopezCastroRoberto --- CMakeLists.txt | 21 +- .../fp4/nvfp4_blockwise_moe_kernel.cu | 6 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 14 +- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 2 +- .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 14 +- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 285 ++++++++++++++++++ 6 files changed, 329 insertions(+), 13 deletions(-) create mode 100644 csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index ea56b8451f228..e2cc0ccdef515 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -529,6 +529,25 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require + # CUDA 12.8 or later + cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + # FP4 Archs and flags cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) @@ -541,7 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index a21ee55b65862..03db5cc196d59 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -335,7 +335,7 @@ void run_fp4_blockwise_scaled_group_mm( TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; #endif @@ -356,7 +356,7 @@ void cutlass_fp4_group_mm( const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 // Input validation CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); @@ -398,7 +398,7 @@ void cutlass_fp4_group_mm( TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_fp4_group_mm kernel, vLLM must " - "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA " "12.8 or above."); #endif } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index badbb7e310df0..1b61bd4519fc3 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -16,14 +16,15 @@ #include -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 -void scaled_fp4_quant_sm100a(torch::Tensor const& output, +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, torch::Tensor const& input_sf); #endif -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void scaled_fp4_experts_quant_sm100a( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, @@ -33,8 +34,9 @@ void scaled_fp4_experts_quant_sm100a( void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 - return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); } @@ -44,7 +46,7 @@ void scaled_fp4_experts_quant( torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts) { -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 return scaled_fp4_experts_quant_sm100a( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index d32911357a953..4e080de151648 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -332,7 +332,7 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, int multiProcessorCount, cudaStream_t stream); -void scaled_fp4_quant_sm100a(torch::Tensor const& output, +void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, torch::Tensor const& input_sf) { diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index 61b75e92dfaa0..9cba2828aac2e 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -16,7 +16,7 @@ #include -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, @@ -24,12 +24,22 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& alpha); #endif +#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 +void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha); +#endif + void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, torch::Tensor const& B_sf, torch::Tensor const& alpha) { -#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); +#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should " diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu new file mode 100644 index 0000000000000..89de23b76e65d --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +#include "core/math.hpp" + +using namespace cute; + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +struct sm120_fp4_config_M256 { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + using PerSmTileShape_MNK = Shape<_128, _128, _128>; +}; + +struct sm120_fp4_config_default { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_256, _128, _128>; + using PerSmTileShape_MNK = Shape<_256, _128, _128>; +}; + +template +struct Fp4GemmSm120 { + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutType; + using ElementC = OutType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename Config::MmaTileShape; + using ClusterShape = typename Config::ClusterShape; + using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD, + LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, + LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, + at::Tensor const& B, + at::Tensor const& A_sf, + at::Tensor const& B_sf, + torch::Tensor const& alpha, int M, + int N, int K) { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementD = typename Gemm::ElementD; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementCompute = float; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using Sm1xxBlkScaledConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(M, N, K, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {static_cast(A.data_ptr()), stride_A, + static_cast(B.data_ptr()), stride_B, + static_cast(A_sf.data_ptr()), layout_SFA, + static_cast(B_sf.data_ptr()), layout_SFB}, + {{}, + static_cast(D.data_ptr()), + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + + return arguments; +} + +template +void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, + at::Tensor const& A_sf, at::Tensor const& B_sf, + torch::Tensor const& alpha, int M, int N, int K, + cudaStream_t stream) { + Gemm gemm; + + auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, M, N, K); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); +} + +void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, int m, int n, + int k, cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 256) { + runGemm::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemm::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, int m, int n, + int k, cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 256) { + runGemm::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemm::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha) { +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); + + CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); + CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); + + CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + + TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + TORCH_CHECK(A.sizes()[1] == B.sizes()[1], + "a and b shapes cannot be multiplied (", A.sizes()[0], "x", + A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); + + auto const m = A.sizes()[0]; + auto const n = B.sizes()[0]; + auto const k = A.sizes()[1] * 2; + + constexpr int alignment = 32; + TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, + ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], + "), k: ", k, "."); + TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, + ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); + + auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; + int rounded_m = round_up(m, 128); + int rounded_n = round_up(n, 128); + // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an + // integer. + int rounded_k = round_up(k / 16, 4); + + TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], + "x", B_sf.sizes()[1], ")"); + TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, + "scale_a must be padded and swizzled to a shape (", rounded_m, + "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", + A_sf.sizes()[1], ")"); + TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, + "scale_b must be padded and swizzled to a shape (", rounded_n, + "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", + B_sf.sizes()[1], ")"); + + auto out_dtype = D.dtype(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + + if (out_dtype == at::ScalarType::BFloat16) { + return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, + stream); + } else if (out_dtype == at::ScalarType::Half) { + return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, + stream); + } else { + TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", + out_dtype, ")"); + } +#else + TORCH_CHECK(false, + "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +} \ No newline at end of file