//
// Based off of:
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <optional>

#include "cutlass_extensions/torch_utils.hpp"
#include "core/registration.h"

#include "cutlass/cutlass.h"
#include <limits>

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

#include <array>

namespace vllm::cutlass_w4a8 {

using namespace cute;

// -----------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -----------------------------------------------------------------------------
using MmaType = cutlass::float_e4m3_t;  // A/scale element type
using QuantType = cutlass::int4b_t;     // B element type (packed int4)
static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
static int constexpr ScalePackSize = 8;  // pack 8 scale elements together
static int constexpr PackFactor = 8;     // 8 4-bit values packed into an int32

// A matrix configuration
using ElementA = MmaType;                   // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
using LayoutA_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
constexpr int AlignmentA =
    128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/
                                                  // alignment of A matrix in
                                                  // units of elements (up to
                                                  // 16 bytes)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;

// B matrix configuration
using ElementB = QuantType;                    // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
using LayoutB_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
constexpr int AlignmentB =
    128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/
                                                  // alignment of B matrix in
                                                  // units of elements (up to
                                                  // 16 bytes)
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
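
// Sanity checks on the derived constants above (illustrative): with 8-bit fp8
// activations and 4-bit weights, a 16-byte access holds 16 elements of A and
// 32 elements of B, and the K tile spans 128 fp8 elements.
static_assert(TileShapeK == 128, "expected a K tile of 128 for 8-bit MmaType");
static_assert(AlignmentA == 16, "expected 16 fp8 elements per 16-byte access");
static_assert(AlignmentB == 32, "expected 32 int4 elements per 16-byte access");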

// Define the CuTe layout for reordered quantized tensor B.
// LayoutAtomQuant places values that will be read by the same thread in
// contiguous locations in global memory; it specifies the reordering within a
// single warp's fragment.
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
    LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));

// Group-wise scales
using ElementScale = MmaType;
using LayoutScale = cutlass::layout::RowMajor;

// Per-tok, per-chan scales
using ElementSChannel = float;

// C/D matrix configuration
using ElementC = cutlass::bfloat16_t;       // Element type for C and D operands
using LayoutC = cutlass::layout::RowMajor;  // Layout type for C and D operands
constexpr int AlignmentC =
    128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/
                                                  // alignment of C matrix in
                                                  // units of elements (up to
                                                  // 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;     // Element type for internal accumulation
using ElementCompute = float;         // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that
                                      // supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
using KernelSchedule =
    cutlass::gemm::KernelTmaWarpSpecializedCooperative;  // Kernel to launch,
                                                         // based on the default
                                                         // setting in the
                                                         // Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

// ----------------------------------------------------------------------------
// Kernel template — Tile/Cluster shapes
// ----------------------------------------------------------------------------
template <typename TileShape_MN, typename ClusterShape_MNK>
struct W4A8GemmKernel {
  using TileShape =
      decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
  using ClusterShape = ClusterShape_MNK;

  // Epilogue per-tok, per-chan scales
  using ChTokScalesEpilogue =
      typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
                                         TileShape>;
  using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementSChannel,
          // Transpose layout of D here since we use explicit swap + transpose.
          // The void type for C tells the builder to allocate 0 smem for the C
          // matrix. We can enable this if beta == 0 by changing ElementC to
          // void below.
          ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC, ElementD,
          typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
          EpilogueSchedule,  // This is the only epi supporting the required
                             // swap + transpose.
          EVTCompute>::CollectiveOp;
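
  // Note on the swap + transpose: following the upstream int4 x fp8 example,
  // the quantized, group-scaled operand is carried in the mainloop's "A" slot
  // (see CollectiveMainloopShuffled below), so the kernel actually computes
  // D^T = B^T x A^T over an (n, m, k) problem shape. The transposed C/D
  // layouts above and the reversed D stride in mm() map that result back onto
  // the row-major (m, n) output tensor.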

  // The scale information must get paired with the operand that will be
  // scaled. Here B is scaled, so we make a tuple of B's information and the
  // scale information.
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
          LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloopShuffled, CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;
  using StrideD = typename GemmKernelShuffled::StrideD;
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;

  static torch::Tensor mm(torch::Tensor const& A,
                          torch::Tensor const& B,             // already packed
                          torch::Tensor const& group_scales,  // already packed
                          int64_t group_size,
                          torch::Tensor const& channel_scales,
                          torch::Tensor const& token_scales,
                          std::optional<at::ScalarType> const& maybe_out_type) {
    // TODO: param validation
    int m = A.size(0);
    int k = A.size(1);
    int n = B.size(1);

    // safely cast group_size to int
    TORCH_CHECK(
        group_size > 0 && group_size <= std::numeric_limits<int>::max(),
        "group_size out of supported range for int: ", group_size);
    int const group_size_int = static_cast<int>(group_size);

    // Allocate output
    const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
    auto device = A.device();
    auto stream = at::cuda::getCurrentCUDAStream(device.index());
    torch::Tensor D =
        torch::empty({m, n}, torch::TensorOptions()
                                 .dtype(equivalent_scalar_type_v<ElementD>)
                                 .device(device));

    // prepare arg pointers
    auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
    auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
    auto D_ptr = static_cast<ElementD*>(D.data_ptr());
    // group scales are packed in Arrays of ScalePackSize elements
    auto S_ptr =
        static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
            group_scales.const_data_ptr());

    // runtime layout for B
    auto shape_B = cute::make_shape(n, k, 1);
    LayoutB_Reordered layout_B_reordered =
        cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

    // strides
    int const scale_k = cutlass::ceil_div(k, group_size_int);
    StrideA stride_A =
        cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
    // Reverse stride here due to swap and transpose
    StrideD stride_D =
        cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
    StrideS stride_S = cutlass::make_cute_packed_stride(
        StrideS{}, cute::make_shape(n, scale_k, 1));

    // Create a Gemm::Arguments structure suitable for invoking an instance of
    // Gemm. The A and B tensors, as well as the problem shape, are swapped
    // here (see the swap + transpose note above).
    using Args = typename GemmShuffled::Arguments;
    using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
    using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

    MainloopArguments mainloop_arguments{
        B_ptr, layout_B_reordered, A_ptr, stride_A,
        S_ptr, stride_S,           group_size_int};

    EpilogueArguments epilogue_arguments{
        ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
        nullptr,
        {},  // no C
        D_ptr,
        stride_D};

    Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                   {n, m, k, 1},  // problem shape (swapped)
                   mainloop_arguments,
                   epilogue_arguments};

    // Workspace
    size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
    torch::Tensor workspace = torch::empty(
        workspace_size,
        torch::TensorOptions().dtype(torch::kU8).device(device));

    // Run GEMM
    GemmShuffled gemm;
    CUTLASS_CHECK(gemm.can_implement(arguments));
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
    CUTLASS_CHECK(gemm.run(stream));

    return D;
  }
};
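
// Shape conventions assumed by W4A8GemmKernel::mm above (illustrative,
// derived from the argument handling and the pre-processing utils below):
//   A              [m, k] fp8 (e4m3) activations, row-major
//   B              [k / 8, n] int32, each element holding 8 packed int4
//                  values, already encoded + reordered
//   group_scales   fp8 (e4m3) group-wise scales, ceil_div(k, group_size)
//                  groups per output channel, packed by pack_scale_fp8
//   channel_scales / token_scales   per-output-channel and per-token scales
//                  applied in the epilogue
//   D              [m, n] bf16 output
// For example, m = 8, k = 4096, n = 4096 and group_size = 128 give
// scale_k = 32 and a CUTLASS problem shape of {4096, 8, 4096, 1}.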

// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
using Kernel_256x128_1x1x1 =
    W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
using Kernel_128x256_2x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
using Kernel_128x256_1x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
using Kernel_128x128_1x1x1 =
    W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;

torch::Tensor mm_dispatch(torch::Tensor const& A,
                          torch::Tensor const& B,             // already packed
                          torch::Tensor const& group_scales,  // already packed
                          int64_t group_size,
                          torch::Tensor const& channel_scales,
                          torch::Tensor const& token_scales,
                          std::optional<at::ScalarType> const& maybe_out_type,
                          const std::string& schedule) {
  if (schedule == "256x128_1x1x1") {
    return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "256x64_1x1x1") {
    return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "256x32_1x1x1") {
    return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "256x16_1x1x1") {
    return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x256_2x1x1") {
    return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x256_1x1x1") {
    return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x128_1x1x1") {
    return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x64_1x1x1") {
    return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x32_1x1x1") {
    return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x16_1x1x1") {
    return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  }
  TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
  return {};
}
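
// Schedule strings have the form "<TileM>x<TileN>_<ClusterM>x<ClusterN>x<ClusterK>".
// The K extent of the CTA tile is always TileShapeK (128) and is appended
// inside W4A8GemmKernel; for example, "128x256_2x1x1" selects a 128x256x128
// tile with a 2x1x1 cluster.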

torch::Tensor mm(torch::Tensor const& A,
                 torch::Tensor const& B,             // already packed
                 torch::Tensor const& group_scales,  // already packed
                 int64_t group_size, torch::Tensor const& channel_scales,
                 torch::Tensor const& token_scales,
                 std::optional<at::ScalarType> const& maybe_out_type,
                 std::optional<std::string> maybe_schedule) {
  // requested a specific schedule
  if (maybe_schedule) {
    return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                       token_scales, maybe_out_type, *maybe_schedule);
  }

  std::string schedule;
  int M = A.size(0);
  int K = A.size(1);
  int N = B.size(1);

  // heuristic
  if (M <= 16) {
    schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1";
  } else if (M <= 32) {
    schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1";
  } else if (M <= 64) {
    if (K == 16384 && N == 18432)
      schedule = "256x64_1x1x1";
    else if (N <= 8192 && K <= 8192)
      schedule = "128x32_1x1x1";
    else
      schedule = "128x64_1x1x1";
  } else if (M <= 128) {
    if (K == 16384 && N == 18432)
      schedule = "256x128_1x1x1";
    else if (N <= 8192)
      schedule = "128x64_1x1x1";
    else
      schedule = "128x128_1x1x1";
  } else if (M <= 256) {
    if (N <= 4096)
      schedule = "128x64_1x1x1";
    else if (N <= 8192)
      schedule = "128x128_1x1x1";
    else
      schedule = "128x256_1x1x1";
  } else if (M <= 512 && N <= 4096) {
    schedule = "128x128_1x1x1";
  } else if (M <= 1024) {
    schedule = "128x256_1x1x1";
  } else {
    schedule = "128x256_2x1x1";
  }

  return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                     token_scales, maybe_out_type, schedule);
}

// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
  TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(scales.is_cuda());

  auto packed_scales = torch::empty(
      {scales.numel() * ScalePackSize},
      torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
  auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
  auto packed_scales_ptr =
      static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
          packed_scales.data_ptr());

  cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel());

  return packed_scales;
}

/*
  GPU-accelerated implementation of cutlass::unified_encode_int4b.
  Constructs a lookup table in constant memory to map 8 bits
  (two 4-bit values) at a time. Assumes memory is contiguous
  and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];

__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
                                            size_t nbytes) {
  constexpr size_t V = sizeof(uint4);  // 16 bytes
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t nthreads = size_t(gridDim.x) * blockDim.x;
  const size_t nvec = nbytes / V;

  // 1-D grid-stride loop over 16-byte chunks
  for (size_t vec = tid; vec < nvec; vec += nthreads) {
    uint4 v = reinterpret_cast<const uint4*>(in)[vec];
    uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
    for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
    reinterpret_cast<uint4*>(out)[vec] = v;
  }
}

static bool upload_lut() {
  std::array<uint8_t, 256> lut{};
  auto map_nib = [](uint8_t v) -> uint8_t {
    // 1..7 -> (8 - v); keep 0 and 8..15
    return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
  };
  for (int b = 0; b < 256; ++b) {
    uint8_t lo = b & 0xF;
    uint8_t hi = (b >> 4) & 0xF;
    lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
  }
  cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
                                     /*offset=*/0, cudaMemcpyHostToDevice);
  return (e == cudaSuccess);
}
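
// Worked example of the mapping above: the byte 0x2F holds the nibbles
// hi = 0x2 (+2) and lo = 0xF (-1 in two's complement). The high nibble is
// remapped to 8 - 2 = 6 and the negative low nibble is kept unchanged, so
// kNibbleLUT[0x2F] == 0x6F.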

static bool unified_encode_int4b(cutlass::int4b_t const* in,
                                 cutlass::int4b_t* out, size_t num_int4_elems) {
  // Build/upload LUT
  if (!upload_lut()) return false;

  static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
                "int4 storage must be 1 byte");
  const size_t nbytes = num_int4_elems >> 1;

  auto* in_bytes = reinterpret_cast<const uint8_t*>(in);
  auto* out_bytes = reinterpret_cast<uint8_t*>(out);

  // kernel launch params
  constexpr int block = 256;
  const size_t nvec = nbytes / sizeof(uint4);  // # of 16B vectors
  int grid = int((nvec + block - 1) / block);
  if (grid == 0) grid = 1;  // ensure we still cover the tail in the kernel

  unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
  cudaError_t err = cudaGetLastError();
  return (err == cudaSuccess);
}

torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
  TORCH_CHECK(B.dtype() == torch::kInt32);
  TORCH_CHECK(B.dim() == 2);

  torch::Tensor B_packed = torch::empty_like(B);

  int k = B.size(0) * PackFactor;  // logical k
  int n = B.size(1);
  TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");

  auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
  auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
  auto shape_B = cute::make_shape(n, k, 1);
  auto layout_B = make_layout(shape_B, LayoutRight{});  // row major
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

  bool ok =
      vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
  TORCH_CHECK(ok, "unified_encode_int4b failed");
  cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);

  return B_packed;
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_w4a8_mm", &mm);
  m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
  m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
}

}  // namespace vllm::cutlass_w4a8
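
// Typical call sequence for the ops registered above (illustrative; the
// Python bindings and tensor preparation live outside this file):
//   1. B_packed     = cutlass_encode_and_reorder_int4b(B_int32)
//   2. group_scales = cutlass_pack_scale_fp8(group_scales_fp8)
//   3. D            = cutlass_w4a8_mm(A_fp8, B_packed, group_scales,
//                                     group_size, channel_scales,
//                                     token_scales, out_type, schedule)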