[ROCm] Add skinny gemm bias support for dtypes fp16,bf16,fp8 (#24988)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Hashem Hashemi 2025-09-23 11:31:45 -07:00 committed by yewentao256
parent 65c4513ad8
commit 9689be1e8e
7 changed files with 231 additions and 77 deletions

View File

@ -5,11 +5,14 @@
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,

View File

@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
}
@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
}
@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
}
}
}
@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) {
return rtn;
}
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias,
const int64_t CuCount) {
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
: 1;
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
in_bias->sizes().size() == 2)
? in_bias->size(0)
: 1;
TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} \
}
@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
switch (N_in) {
case 1:
@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
}
}
}
@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
const fp8_t* __restrict__ A, scalar_t* C,
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE;
@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
}
}
}
@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A,
scalar_t* C, const float* __restrict__ s_A,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b,
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
: 1;
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
in_bias->sizes().size() == 2)
? in_bias->size(0)
: 1;
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} \
}
@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
auto a_ptr = in_a.data_ptr<fp8_t>();
auto b_ptr = in_b.data_ptr<fp8_t>();
auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<fptype*>(in_bias->data_ptr())
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)

View File

@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8
rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
"Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);

View File

@ -1,12 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
rocm_per_tensor_w8a8_scaled_mm_impl)
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
@ -49,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [
(2, 512, 512),
(3, 2048, 2048),
(4, 4096, 4096),
(4, 16400, 2048),
# Extended FP8 dimensions not covered by WVSPLITK
(1, 14336, 1024),
(2, 24576, 2048),
@ -67,6 +68,9 @@ SEEDS = [0]
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
#TODO: Zero-centering the inputs causes errors for LLMM1!
# Without that the numbers quickly saturate, and may
# be giving false matches.
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda")
A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
ref_out = torch.matmul(A, B.t())
out = ops.wvSplitK(B, A, cu_count)
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda")
B = torch.rand(m, k, device="cuda")
A = torch.rand(n, k, device="cuda") - 0.5
B = torch.rand(m, k, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda")
B = torch.rand(m, k, device="cuda")
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - .5) * xavier
B = (torch.rand(m, k, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
scale_b, bias)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
assert torch.allclose(output, ref_out, rtol=0.01)
bias=BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count(), BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)

View File

@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor,
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
def wvSplitK(a: torch.Tensor,
b: torch.Tensor,
cu_count: int,
bias: torch.Tensor = None) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cu_count: int) -> torch.Tensor:
def wvSplitKQ(a: torch.Tensor,
b: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cu_count: int,
bias: torch.Tensor = None) -> torch.Tensor:
out = torch.empty((b.shape[0], a.shape[0]),
dtype=out_dtype,
device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
return out

View File

@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \
qinput.shape[0] == 1 and \
qinput.shape[1] % 16 == 0 and \
((bias is None) or (bias.dtype == out_dtype)) :
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
current_platform.get_cu_count(), bias)
else:
output = torch._scaled_mm(qinput,
weight,

View File

@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl(
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
and k % 8 == 0)
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl(
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n <= 4:
out = ops.wvSplitK(weight, x_view, cu_count)
out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)