// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

using namespace cute;
using namespace vllm;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& bt_nzs,
                                    torch::Tensor const& bt_meta,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM512 =
      typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

  using Cutlass3xGemm1 =
      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm2 =
      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm3 =
      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm4 =
      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm5 =
      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm6 =
      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm7 =
      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const n = bt_nzs.size(0);
  uint32_t const m = a.size(0);  // Batch size
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  // Shape-specialized configs, keyed on the next power of two of the batch
  // size m and a few common weight shapes n.
  if (mp2 <= 64) {
    if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 4096 || n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    if (n == 6144 || n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  }

  // Otherwise the default heuristic, keyed on m alone
  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // m in (128, 256]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (256, inf)
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
}
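// The heuristics above and below key on next_pow_2(m), which is provided by
// the shared headers included at the top of this file. For reference, a
// minimal sketch of the assumed behavior (given a hypothetical name here so
// it cannot clash with the real helper):
namespace {
[[maybe_unused]] constexpr uint32_t next_pow_2_ref(uint32_t x) {
  // Round x up to the nearest power of two, e.g. 1 -> 1, 3 -> 4, 100 -> 128.
  uint32_t p = 1;
  while (p < x) {
    p <<= 1;
  }
  return p;
}
}  // namespace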
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out,
                                     torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::half_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Only the default config is tuned for fp16; use it for all shapes.
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out,
                                     torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::bfloat16_t>());
  TORCH_CHECK(a.dtype() == torch::kBFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Only the default config is tuned for bf16; use it for all shapes.
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
                                     torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
}

template