mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:15:01 +08:00
[BugFix] [Kernel] Add Cutlass2x fallback kernels (#5744)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
832ea88fcb
commit
6c916ac8a8
@ -17,3 +17,11 @@ inline uint32_t next_pow_2(uint32_t const num) {
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
|
||||
@ -250,12 +250,39 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
// the FallbackGemm instead.
|
||||
static const int max_shared_mem_per_block_opt_in =
|
||||
get_cuda_max_shared_memory_per_block_opt_in(0);
|
||||
|
||||
size_t const gemm_shared_mem_size =
|
||||
sizeof(typename Gemm::KernelType::SharedStorage);
|
||||
size_t const fallback_gemm_shared_mem_size =
|
||||
sizeof(typename FallbackGemm::KernelType::SharedStorage);
|
||||
|
||||
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (128, inf)
|
||||
// - M in (64, 128] and N >= 8192
|
||||
// Shared Memory required by this Gemm - 81920 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
@ -271,6 +298,7 @@ struct sm80_config_M64 {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (32, 64]
|
||||
// - M in (64, 128] and N < 8192
|
||||
// Shared Memory required by this Gemm - 122880 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
@ -284,6 +312,7 @@ template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M32 {
|
||||
// M in (16, 32]
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
@ -297,6 +326,7 @@ template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M16 {
|
||||
// M in [1, 16]
|
||||
// Shared Memory required by this Gemm - 51200 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
@ -331,35 +361,45 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
using Cutlass2xGemmM16 =
|
||||
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm80_config_M16 has the least shared-memory requirement. However,
|
||||
// based on some profiling, we select sm80_config_M32 as a better alternative
|
||||
// performance wise.
|
||||
using FallbackGemm =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return cutlass_gemm_caller<Cutlass2xGemmM16>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return cutlass_gemm_caller<Cutlass2xGemmM32>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return cutlass_gemm_caller<Cutlass2xGemmM64>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
uint32_t const n = out.size(1);
|
||||
bool const small_n = n < 8192;
|
||||
if (small_n) {
|
||||
return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
|
||||
FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else {
|
||||
// M in (128, inf)
|
||||
return cutlass_gemm_caller<Cutlass2xGemmDefault>(
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user