Improve wvsplitK tile and balance heristics. (#29937)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi 2025-12-09 15:51:32 -08:00 committed by GitHub
parent 3c680f4a17
commit 2e7054da06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
} }
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support #endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) { int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2; int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd; int rnds[13];
nPrRnd -= div1 * 3; for (int i = 0; i < 13; i++) {
int rnds3 = N / nPrRnd; rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1; nPrRnd -= div1;
int rnds4 = N / nPrRnd; }
nPrRnd -= div1; for (int i = 12; i >= 0; i--)
int rnds5 = N / nPrRnd; if (rnds[0] == rnds[i]) return (div2 - i);
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
} }
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2; const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ #define WVSPLITK(_YTILE, _UNRL, _N) \
_N) \ { \
{ \ dim3 block(64, 16); \
dim3 block(64, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ biasf4, c, __wvPrGrp, CuCount); \
biasf4, c, __wvPrGrp, CuCount); \ else if (K_in * N_in <= max_lds_len * 1.2) \
} else if (K_in * N_in <= max_lds_len * 1.2) { \ wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ biasf4, c, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ else \
biasf4, c, __wvPrGrp, CuCount); \ wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
} else { \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ biasf4, c, __wvPrGrp, CuCount); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ }
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \ #define WVSPLIT_TILE(_sYT, __N) \
} \ { \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
} }
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr()) ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr; : nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr()); fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) { switch (N_in) {
case 1: case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) WVSPLIT_TILE(sYT, 1)
break; break;
case 2: case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) WVSPLIT_TILE(sYT, 2)
break; break;
case 3: case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) WVSPLIT_TILE(sYT, 3)
break; break;
case 4: case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) WVSPLIT_TILE(sYT, 4)
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(