// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                             torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
  TORCH_CHECK(a.dim() == 2);
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0);  // Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1);

  int m = a.size(0);
  int k = a.size(1);

  // Sparse kernel setup; this kernel is not used for matmul,
  // but just for setting up the compressor utility
  // A matrix configuration
  using ElementA = ElementA_;
  using LayoutTagA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration
  using ElementB = ElementA;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // C/D matrix configuration
  using ElementC = float;
  using LayoutTagC = cutlass::layout::ColumnMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations
  using ElementAccumulator = ElementAcc_;
  using TileShape = Shape<_128, _128, _128>;
  using TileShapeRef = Shape<_128, _128, _64>;
  using ClusterShape = Shape<_1, _2, _1>;
  using KernelSchedule = typename std::conditional<
      std::is_same_v<ElementA, cutlass::float_e4m3_t>,
      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
      cutlass::gemm::KernelTmaWarpSpecialized>::type;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
  using ProblemShape = Shape<int, int, int, int>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
          AlignmentC, ElementC, LayoutTagC, AlignmentC,
          EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
          LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
  using StrideE = StrideA;

  // The n (=1) dimension does not matter for the compressor
  typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};

  using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
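
  // Note: the GemmKernel/Gemm types above are never launched. They are built
  // only so the sparse collective mainloop can expose the operand layouts
  // (LayoutA/LayoutE), the metadata element type (ElementE) and the
  // SparseConfig that the offline compressor below must agree with.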
  // Offline compressor kernel
  using CompressorUtility =
      cutlass::transform::kernel::StructuredSparseCompressorUtility<
          ProblemShape, ElementA, LayoutTagA, SparseConfig>;

  using CompressorKernel =
      cutlass::transform::kernel::StructuredSparseCompressor<
          ProblemShape, ElementA, LayoutTagA, SparseConfig,
          cutlass::arch::Sm90>;

  using Compressor =
      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;

  auto [M, N, K, L] = prob_shape;

  StrideA stride_A;
  stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));

  CompressorUtility compressor_utility(prob_shape, stride_A);

  // Physical extents of the compressed operand and its metadata
  int ME = compressor_utility.get_metadata_m_physical();
  int KE = compressor_utility.get_metadata_k_physical();
  int KC = compressor_utility.get_tensorA_k_physical();

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = 0;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(compressor_op.run());
  CUDA_CHECK(cudaDeviceSynchronize());

  return true;
}

bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a) {
  if (a.dtype() == torch::kBFloat16) {
    return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
                                                               a);
  } else if (a.dtype() == torch::kFloat16) {
    return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
                                                                 a);
  } else if (a.dtype() == torch::kInt8) {
    return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
  }
  return false;
}
#endif
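
// Usage sketch (illustrative only; apart from the entry-point name and the
// dtype checks above, the shapes and helper calls below are assumptions, not
// guarantees of this file). Callers pre-allocate both outputs: with 2:4
// semi-structured sparsity the non-zero tensor keeps half of A's k entries
// per row, and the metadata tensor's physical extents are those reported by
// CompressorUtility.
//
//   torch::Tensor a = ...;  // [m, k], row-major, int8/fp8_e4m3/fp16/bf16
//   torch::Tensor a_nzs = torch::empty({m, k / 2}, a.options());
//   torch::Tensor a_meta = ...;  // metadata buffer sized per SparseConfig
//   bool ok = cutlass_sparse_compress_sm90(a_nzs, a_meta, a);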