mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 23:29:10 +08:00
Improve wvsplitK tile and balance heristics. (#29937)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
parent
3c680f4a17
commit
2e7054da06
@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
|||||||
}
|
}
|
||||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||||
|
|
||||||
|
// Find the min val of div2 that doesn't increase N/(div1*div2)
|
||||||
int mindiv(int N, int div1, int div2) {
|
int mindiv(int N, int div1, int div2) {
|
||||||
int nPrRnd = div1 * div2;
|
int nPrRnd = div1 * div2;
|
||||||
int rnds0 = N / nPrRnd;
|
int rnds[13];
|
||||||
nPrRnd -= div1 * 3;
|
for (int i = 0; i < 13; i++) {
|
||||||
int rnds3 = N / nPrRnd;
|
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
|
||||||
nPrRnd -= div1;
|
nPrRnd -= div1;
|
||||||
int rnds4 = N / nPrRnd;
|
}
|
||||||
nPrRnd -= div1;
|
for (int i = 12; i >= 0; i--)
|
||||||
int rnds5 = N / nPrRnd;
|
if (rnds[0] == rnds[i]) return (div2 - i);
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds6 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds7 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds8 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds9 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rtn = div2;
|
|
||||||
if (rnds0 == rnds3) rtn = div2 - 3;
|
|
||||||
if (rnds0 == rnds4) rtn = div2 - 4;
|
|
||||||
if (rnds0 == rnds5) rtn = div2 - 5;
|
|
||||||
if (rnds0 == rnds6) rtn = div2 - 6;
|
|
||||||
if (rnds0 == rnds7) rtn = div2 - 7;
|
|
||||||
if (rnds0 == rnds8) rtn = div2 - 8;
|
|
||||||
if (rnds0 == rnds9) rtn = div2 - 9;
|
|
||||||
return rtn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||||
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
const int max_lds_len = get_lds_size() / 2;
|
const int max_lds_len = get_lds_size() / 2;
|
||||||
|
|
||||||
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
||||||
_N) \
|
{ \
|
||||||
{ \
|
dim3 block(64, 16); \
|
||||||
dim3 block(64, _WvPrGrp); \
|
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
||||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
else if (K_in * N_in <= max_lds_len * 1.2) \
|
||||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
else \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
} else { \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
|
}
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
#define WVSPLIT_TILE(_sYT, __N) \
|
||||||
} \
|
{ \
|
||||||
|
bool fit_lds = (K_in * N_in <= max_lds_len); \
|
||||||
|
if (_sYT <= 1) \
|
||||||
|
WVSPLITK(1, 4, __N) \
|
||||||
|
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
||||||
|
WVSPLITK(2, 2, __N) \
|
||||||
|
else if (_sYT <= 4 * 3) \
|
||||||
|
WVSPLITK(3, 2, __N) \
|
||||||
|
else if (__N == 4) \
|
||||||
|
WVSPLITK(4, 1, __N) \
|
||||||
|
else \
|
||||||
|
WVSPLITK(4, 2, __N) \
|
||||||
}
|
}
|
||||||
|
|
||||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
||||||
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||||
: nullptr;
|
: nullptr;
|
||||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||||
|
|
||||||
|
// first shoot for biggest tile-size that keeps all simd busy,
|
||||||
|
// then cut the active waves to balance their distribution...
|
||||||
|
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
|
||||||
|
|
||||||
switch (N_in) {
|
switch (N_in) {
|
||||||
case 1:
|
case 1:
|
||||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
|
WVSPLIT_TILE(sYT, 1)
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
|
WVSPLIT_TILE(sYT, 2)
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
|
WVSPLIT_TILE(sYT, 3)
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
|
WVSPLIT_TILE(sYT, 4)
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user