From 9689be1e8eb60d4ccc50ed1f47550b5ffb4ebc1d Mon Sep 17 00:00:00 2001 From: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Date: Tue, 23 Sep 2025 11:31:45 -0700 Subject: [PATCH] [ROCm] Add skinny gemm bias support for dtypes fp16,bf16,fp8 (#24988) Signed-off-by: Hashem Hashemi Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: yewentao256 --- csrc/rocm/ops.h | 9 +- csrc/rocm/skinny_gemms.cu | 181 ++++++++++++++---- csrc/rocm/torch_bindings.cpp | 5 +- .../quantization/test_rocm_skinny_gemms.py | 80 ++++++-- vllm/_custom_ops.py | 19 +- .../layers/quantization/utils/w8a8_utils.py | 8 +- vllm/model_executor/layers/utils.py | 6 +- 7 files changed, 231 insertions(+), 77 deletions(-) diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b6ee2656746c1..edf7aff1abaac 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -5,11 +5,14 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, const int64_t rows_per_block); -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, const int64_t CuCount); -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, + const int64_t CuCount); void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index b8a1b439758c1..bf2fe169c7114 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); } } @@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); } } @@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, + const int By, const 
scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] 
+= + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) { return rtn; } -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, const int64_t CuCount) { auto M_in = in_a.size(0); auto K_in = in_a.size(1); auto N_in = in_b.size(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? in_bias->size(0) + : 1; TORCH_CHECK(in_a.dtype() == in_b.dtype()); TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); @@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else if (K_in * N_in <= max_lds_len * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitK_hf_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSplitK_hf_big_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } \ } @@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, using fptype = typename scalar::type; fptype* af4 = reinterpret_cast(in_a.data_ptr()); const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + const fptype* biasf4 = + (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); switch (N_in) { case 1: @@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + if (y + m >= M) break; // To avoid mem access fault. 
+ sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); // * sA * sB); } } } @@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE; @@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); } } } @@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, - scalar_t* C, const float* __restrict__ s_A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, const int64_t CuCount) { static c10::ScalarType kFp8Type = is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn @@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, auto K_in = in_a.size(1); auto N_in = in_b.size(0); auto Kp_in = in_a.stride(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? 
in_bias->size(0) + : 1; + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); TORCH_CHECK(out_c.dtype() == torch::kFloat16 || @@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitKQ_hf_sml_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitKQ_hf_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } \ } @@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { auto a_ptr = in_a.data_ptr(); auto b_ptr = in_b.data_ptr(); + auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; switch (N_in) { case 1: WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index c0c4daef64f05..518486b1ca5de 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> " "Tensor"); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); // wvSplitK for fp8 rocm_ops.def( - "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, " + "Tensor scale_a, " " Tensor scale_b, int CuCount) -> ()"); rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index a9b1c71ef0718..6de5fc9c56010 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + import pytest import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - rocm_per_tensor_w8a8_scaled_mm_impl) from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] @@ -49,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [ (2, 512, 512), (3, 2048, 2048), (4, 4096, 4096), + (4, 16400, 2048), # Extended FP8 dimensions not covered by WVSPLITK (1, 14336, 1024), (2, 24576, 2048), @@ -67,6 +68,9 @@ SEEDS = [0] @torch.inference_mode() def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): torch.manual_seed(seed) + #TODO: Zero-centering the inputs causes errors for LLMM1! + # Without that the numbers quickly saturate, and may + # be giving false matches. 
A = torch.rand(n, k, dtype=dtype, device="cuda") B = torch.rand(m, k, dtype=dtype, device="cuda") @@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() - A = torch.rand(n, k, dtype=dtype, device="cuda") - B = torch.rand(m, k, dtype=dtype, device="cuda") + A = torch.rand(n, k, dtype=dtype, device="cuda") - .5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - .5 - ref_out = torch.matmul(A, B.t()) - out = ops.wvSplitK(B, A, cu_count) + ref_out = torch.nn.functional.linear(A, B) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier + BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) assert torch.allclose(out, ref_out, rtol=0.01) @@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - A = torch.rand(n, k, device="cuda") - B = torch.rand(m, k, device="cuda") + A = torch.rand(n, k, device="cuda") - 0.5 + B = torch.rand(m, k, device="cuda") - 0.5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) @@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), reason="only test for rocm fp8") -def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias): +def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - A = torch.rand(n, k, device="cuda") - B = torch.rand(m, k, device="cuda") + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, device="cuda") - .5) * xavier + B = (torch.rand(m, k, device="cuda") - .5) * xavier + BIAS = torch.rand(m, dtype=dtype, 
device="cuda") - .5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None - - output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a, - scale_b, bias) ref_out = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, - bias=bias) - assert torch.allclose(output, ref_out, rtol=0.01) + bias=BIAS) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, + current_platform.get_cu_count(), BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 712295aa92886..a108542e14368 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor, return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) -def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitK(a, b, cu_count) +def wvSplitK(a: torch.Tensor, + b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) -def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cu_count: int) -> torch.Tensor: +def wvSplitKQ(a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None) -> torch.Tensor: out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) - torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8cda1789e6c97..6ed482db4700e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( - ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None: + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \ + qinput.shape[0] == 1 and \ + qinput.shape[1] % 16 == 0 and \ + ((bias is None) or (bias.dtype == out_dtype)) : output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, - current_platform.get_cu_count()) + current_platform.get_cu_count(), bias) else: output = torch._scaled_mm(qinput, weight, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a1675ffbaa950..d7a65d43c2107 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl( k = weight.shape[1] use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ x.dtype in [torch.float16, torch.bfloat16] \ - and k % 8 == 0 and bias is None) + and k % 8 == 0) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) @@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl( cu_count = current_platform.get_cu_count() if m > 8 and 0 < n <= 4: - out = ops.wvSplitK(weight, x_view, cu_count) + out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.view(*x.shape[:-1], weight.shape[0]) - elif m % 4 == 0 and n 
== 1 and k <= 8192: + elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: out = ops.LLMM1(weight, x_view, 4) return out.view(*x.shape[:-1], weight.shape[0]) return torch.nn.functional.linear(x, weight, bias)
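The kernels above receive the bias as a flat pointer plus two extents: Bx is the size of the last bias dimension and By the size of the leading one (both 1 when no bias is passed). Each accumulated dot product is offset by BIAS[(m + i) % Bx + (n % By) * M] before the final down-convert, which covers a 1-D bias of shape (M,) as well as a 2-D bias of shape (N, M); the addressing assumes the last bias dimension equals the output width M, which is what the new tests exercise. A minimal host-side emulation of that addressing, with illustrative helper names that are not part of the patch:

import torch

def add_bias_like_kernel(acc: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Emulates the kernel's BIAS[(m + i) % Bx + (n % By) * M] addressing.
    # acc:  (N, M) float32 accumulator, one entry per output element.
    # bias: shape (M,) or (N, M), the two layouts covered by the new tests.
    N, M = acc.shape
    Bx = bias.size(-1)                           # last bias dimension
    By = bias.size(0) if bias.dim() == 2 else 1  # leading bias dimension
    flat = bias.reshape(-1).float()
    out = acc.clone()
    for n in range(N):
        for m in range(M):
            out[n, m] += flat[(m % Bx) + (n % By) * M]
    return out

For these two layouts the addressing reduces to ordinary broadcasting, i.e. add_bias_like_kernel(acc, bias) equals acc + bias.float().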
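On a ROCm gfx9 build with the custom ops compiled, the extended fp16/bf16 entry point can be driven the same way the new bias tests do. A short usage sketch; the sizes are illustrative and chosen only to satisfy the kernel constraints (k % 8 == 0, a skinny activation with at most 4 rows):

import math
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform

n, k, m = 2, 4096, 8192                     # skinny: n <= 4, k % 8 == 0
dtype = torch.bfloat16
scale = math.sqrt(2 / k)                    # keep output and bias comparable
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * scale  # activations
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * scale  # weight
bias = torch.rand(m, dtype=dtype, device="cuda") - 0.5            # 1-D bias

# The bias is the new trailing argument; passing None (or omitting it)
# preserves the previous behaviour.
out = ops.wvSplitK(B, A, current_platform.get_cu_count(), bias)
ref = torch.nn.functional.linear(A, B, bias)
assert torch.allclose(out, ref, rtol=0.01)

A 2-D bias of shape (n, m) is accepted as well, as covered by test_rocm_wvsplitk_bias2D_kernel.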
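The fp8 path follows the same pattern through ops.wvSplitKQ, which now takes the bias after the CU count and expects the bias dtype to match the requested output dtype. The sketch below mirrors test_rocm_wvsplitk_fp8_bias1D_kernel and assumes an MI3xx-class GPU with fp8 support; ref_dynamic_per_tensor_fp8_quant is the quantization helper already used by the test suite:

import math
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform

n, k, m = 2, 512, 512                        # k % 16 == 0 is required here
out_dtype = torch.float16
xavier = math.sqrt(2 / k)                    # keep output and bias comparable
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
bias = torch.rand(m, dtype=out_dtype, device="cuda") - 0.5

A_q, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B_q, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

out = ops.wvSplitKQ(B_q, A_q, out_dtype, scale_a, scale_b,
                    current_platform.get_cu_count(), bias)
ref = torch._scaled_mm(A_q, B_q.t(), out_dtype=out_dtype,
                       scale_a=scale_a, scale_b=scale_b, bias=bias)
assert torch.allclose(out, ref, rtol=0.01)

Note that the kernels apply the per-tensor scales to the accumulator before adding the bias, matching the bias handling of torch._scaled_mm used as the reference.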
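Outside the kernels and tests, the behavioural change is confined to the dispatch gates: rocm_per_tensor_w8a8_scaled_mm_impl no longer falls back to torch._scaled_mm just because a bias is present (it only requires the bias dtype to match the output dtype), and rocm_unquantized_gemm_impl forwards the bias to wvSplitK while keeping the bias-free requirement on the LLMM1 branch. A condensed restatement of the updated conditions, using hypothetical argument names for readability:

def takes_fp8_skinny_path(use_skinny_env: bool, on_mi3xx: bool,
                          qinput, bias, out_dtype) -> bool:
    # Gate from rocm_per_tensor_w8a8_scaled_mm_impl: a bias no longer
    # disqualifies wvSplitKQ as long as its dtype matches out_dtype.
    return (use_skinny_env and on_mi3xx
            and qinput.shape[0] == 1
            and qinput.shape[1] % 16 == 0
            and (bias is None or bias.dtype == out_dtype))

def takes_llmm1_path(m: int, n: int, k: int, bias) -> bool:
    # Gate from rocm_unquantized_gemm_impl: LLMM1 still has no bias support,
    # so that branch keeps the explicit bias-is-None check.
    return m % 4 == 0 and n == 1 and k <= 8192 and bias is None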