Improve wvsplitK tile and balance heristics. (#29937)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi 2025-12-09 15:51:32 -08:00 committed by GitHub
parent 3c680f4a17
commit 2e7054da06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd;
nPrRnd -= div1 * 3;
int rnds3 = N / nPrRnd;
nPrRnd -= div1;
int rnds4 = N / nPrRnd;
nPrRnd -= div1;
int rnds5 = N / nPrRnd;
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
int rnds[13];
for (int i = 0; i < 13; i++) {
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1;
}
for (int i = 12; i >= 0; i--)
if (rnds[0] == rnds[i]) return (div2 - i);
}
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) {
case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
WVSPLIT_TILE(sYT, 1)
break;
case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
WVSPLIT_TILE(sYT, 2)
break;
case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
WVSPLIT_TILE(sYT, 3)
break;
case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
WVSPLIT_TILE(sYT, 4)
break;
default:
throw std::runtime_error(