diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 2ef579a1b7537..8ebe55cef391d 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
 }
 #endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
+// Find the min val of div2 that doesn't increase N/(div1*div2)
 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
-  int rnds0 = N / nPrRnd;
-  nPrRnd -= div1 * 3;
-  int rnds3 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds4 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds5 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds6 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds7 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds8 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds9 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rtn = div2;
-  if (rnds0 == rnds3) rtn = div2 - 3;
-  if (rnds0 == rnds4) rtn = div2 - 4;
-  if (rnds0 == rnds5) rtn = div2 - 5;
-  if (rnds0 == rnds6) rtn = div2 - 6;
-  if (rnds0 == rnds7) rtn = div2 - 7;
-  if (rnds0 == rnds8) rtn = div2 - 8;
-  if (rnds0 == rnds9) rtn = div2 - 9;
-  return rtn;
+  int rnds[13];
+  for (int i = 0; i < 13; i++) {
+    rnds[i] = (N + nPrRnd - 1) / nPrRnd;
+    nPrRnd -= div1;
+  }
+  for (int i = 12; i >= 0; i--)
+    if (rnds[0] == rnds[i]) return (div2 - i);
 }
 
 torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   const int max_lds_len = get_lds_size() / 2;
 
-#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
-                 _N)                                                          \
-  {                                                                           \
-    dim3 block(64, _WvPrGrp);                                                 \
-    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {              \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);              \
-      wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N>          \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
-                                       biasf4, c, __wvPrGrp, CuCount);        \
-    } else if (K_in * N_in <= max_lds_len * 1.2) {                            \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp);              \
-      wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N>              \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
-                                       biasf4, c, __wvPrGrp, CuCount);        \
-    } else {                                                                  \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp);              \
-      wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N>          \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
-                                       biasf4, c, __wvPrGrp, CuCount);        \
-    }                                                                         \
+#define WVSPLITK(_YTILE, _UNRL, _N)                                           \
+  {                                                                           \
+    dim3 block(64, 16);                                                       \
+    int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16);                       \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0))                 \
+      wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>                  \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
+                                       biasf4, c, __wvPrGrp, CuCount);        \
+    else if (K_in * N_in <= max_lds_len * 1.2)                                \
+      wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>                      \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
+                                       biasf4, c, __wvPrGrp, CuCount);        \
+    else                                                                      \
+      wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N>                  \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4,    \
+                                       biasf4, c, __wvPrGrp, CuCount);        \
+  }
+
+#define WVSPLIT_TILE(_sYT, __N)                                               \
+  {                                                                           \
+    bool fit_lds = (K_in * N_in <= max_lds_len);                              \
+    if (_sYT <= 1)                                                            \
+      WVSPLITK(1, 4, __N)                                                     \
+    else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2))                     \
+      WVSPLITK(2, 2, __N)                                                     \
+    else if (_sYT <= 4 * 3)                                                   \
+      WVSPLITK(3, 2, __N)                                                     \
+    else if (__N == 4)                                                        \
+      WVSPLITK(4, 1, __N)                                                     \
+    else                                                                      \
+      WVSPLITK(4, 2, __N)                                                     \
   }
 
   AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
                        ? reinterpret_cast<fptype*>(in_bias->data_ptr())
                        : nullptr;
     fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
+
+    // first shoot for biggest tile-size that keeps all simd busy,
+    // then cut the active waves to balance their distribution...
+    int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
+
     switch (N_in) {
       case 1:
-        WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
+        WVSPLIT_TILE(sYT, 1)
         break;
       case 2:
-        WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
+        WVSPLIT_TILE(sYT, 2)
        break;
       case 3:
-        WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
+        WVSPLIT_TILE(sYT, 3)
         break;
       case 4:
-        WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
+        WVSPLIT_TILE(sYT, 4)
         break;
       default:
         throw std::runtime_error(
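
Note on the new `mindiv()`: with `w` waves per workgroup, one kernel round covers `div1 * w` rows (here `div1 = CuCount * _YTILE`), so `N` rows take `ceil(N / (div1 * w))` rounds. The rewrite tabulates that round count for `w = div2` down to `w = div2 - 12` and returns the smallest `w` that still finishes in as many rounds as the full `div2`, i.e. it trims waves that would otherwise sit idle. A minimal standalone sketch of the same logic (not part of the patch; the `main()` values are made-up examples, and the trailing `return` exists only to keep the sketch warning-clean):

```cpp
#include <cstdio>

// Same logic as the patched mindiv(): rnds[i] is the ceil-div round count
// when running with (div2 - i) waves per workgroup.
int mindiv(int N, int div1, int div2) {
  int nPrRnd = div1 * div2;
  int rnds[13];
  for (int i = 0; i < 13; i++) {
    rnds[i] = (N + nPrRnd - 1) / nPrRnd;
    nPrRnd -= div1;
  }
  // Smallest wave count whose round count matches full occupancy.
  for (int i = 12; i >= 0; i--)
    if (rnds[0] == rnds[i]) return div2 - i;
  return div2;  // unreachable: i == 0 always matches
}

int main() {
  // Hypothetical 304-CU GPU, YTILE = 4, up to 16 waves per group.
  const int CuCount = 304, YTILE = 4, WvPrGrp = 16;
  for (int M : {4096, 8192, 28672}) {
    int w = mindiv(M, CuCount * YTILE, WvPrGrp);
    int rounds = (M + CuCount * YTILE * w - 1) / (CuCount * YTILE * w);
    printf("M=%5d -> %2d waves/group, %d round(s)\n", M, w, rounds);
  }
  return 0;
}
```

For M = 4096 this picks 4 waves (the minimum probed), since even 4 waves finish in one round; larger M keeps more waves active.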
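
The dispatch itself is a two-step heuristic, per the new comment: `sYT = ceil(M_in / (CuCount * 4))` is roughly how many YTILE-4 tiles each CU would have to cover, and `WVSPLIT_TILE` picks the biggest `_YTILE` that still gives every CU work before `mindiv()` trims the waves. A plain-function sketch of the macro's selection chain (illustrative only; `pick_tile` and its `std::pair` return are not in the patch):

```cpp
#include <utility>

// Mirrors WVSPLIT_TILE's if/else chain: returns the {_YTILE, _UNRL} pair
// the macro would dispatch with. Thresholds copied from the patch.
std::pair<int, int> pick_tile(int sYT, int N, bool fit_lds) {
  if (sYT <= 1) return {1, 4};  // M_in <= 4*CuCount: 1-row tiles, more unroll
  if (N == 1 || !fit_lds || sYT <= 4 * 2) return {2, 2};
  if (sYT <= 4 * 3) return {3, 2};
  if (N == 4) return {4, 1};    // the patch drops unroll to 1 for N == 4
  return {4, 2};                // plenty of rows per CU: largest tile
}
```

For example, on a 304-CU part with `M_in = 8192` and the operand fitting in LDS, `sYT = ceil(8192 / 1216) = 7 <= 4 * 2`, so the `_YTILE = 2, _UNRL = 2` variant is launched.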