[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010)

Signed-off-by: charlifu <charlifu@amd.com>
Author: Charlie Fu
Date: 2025-05-31 09:40:05 -05:00 (committed by GitHub)
parent f2c3f66d59
commit 306d60401d
5 changed files with 91 additions and 54 deletions

View File

@ -13,14 +13,34 @@
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__
#endif
#if defined(__gfx950__)
#define LDS_SIZE 160 * 1024
#else
#define LDS_SIZE 64 * 1024
#endif
int get_lds_size() {
static bool is_cached = false;
static int result;
if (is_cached == false) {
auto dprops = at::cuda::getCurrentDeviceProperties();
std::string device_arch = dprops->gcnArchName;
size_t substring = device_arch.find("gfx95");
result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
is_cached = true;
}
return result;
}
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
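The capacity logic added here has two halves: a compile-time LDS_SIZE picked per compilation target (gfx950 builds get 160 KB of LDS per CU, the other GFX9 targets get 64 KB) and the host-side get_lds_size(), which checks gcnArchName at runtime so the launch code agrees with whatever the kernels were compiled for. Below is a minimal standalone sketch of that selection (plain host C++; lds_bytes_for_arch and the sample arch strings are illustrative, not vLLM code), which also prints how many 2-byte (half/bf16) and 1-byte (fp8) elements each budget holds.

// Hedged sketch of the LDS budget selection above; not part of the patch.
#include <cstdio>
#include <string>

static int lds_bytes_for_arch(const std::string& gcn_arch_name) {
  // Mirrors get_lds_size(): gfx95* devices report 160 KB of LDS per CU,
  // earlier GFX9 devices (gfx90a, gfx942) report 64 KB.
  return gcn_arch_name.find("gfx95") != std::string::npos ? 160 * 1024
                                                          : 64 * 1024;
}

int main() {
  for (const char* arch : {"gfx90a:sramecc+:xnack-", "gfx942", "gfx950"}) {
    const int lds = lds_bytes_for_arch(arch);
    std::printf("%-24s LDS = %3d KB, half/bf16 elems = %6d, fp8 elems = %6d\n",
                arch, lds / 1024, lds / 2, lds);
  }
  return 0;
}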
@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
V0 += (s.x + s.y); \
}
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64/160 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Fetch the activation matrix to LDS
@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
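The loop above is the cooperative LDS fill: each thread copies one A_CHUNK-wide vector per iteration, the whole work group strides by THRDS * WvPrGrp * A_CHUNK, and the copy stops once min(K * N, max_lds_len) elements of A are resident. Here is a freestanding device-side sketch of the same access pattern (a hypothetical helper, not vLLM code; the real kernel uses vectorized bigType loads rather than the scalar inner loop shown here).

// Illustrative HIP/CUDA device helper: stages the first min(total, MAX_LDS_LEN)
// elements of A into the shared-memory buffer s using the same block-strided
// indexing as wvSplitK_hf_sml_ above.
#include <hip/hip_runtime.h>
#include <cstdint>

template <typename scalar_t, int THRDS, int WvPrGrp, int A_CHUNK, int MAX_LDS_LEN>
__device__ void stage_activations(const scalar_t* __restrict__ A,
                                  scalar_t* __restrict__ s, uint32_t total) {
  const uint32_t limit = min(total, static_cast<uint32_t>(MAX_LDS_LEN));
  for (uint32_t k = 0; k < limit; k += THRDS * WvPrGrp * A_CHUNK) {
    const uint32_t k_in = k + (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
    if (k_in >= limit) break;
    // The real kernel moves A_CHUNK elements with a single vector (bigType)
    // load/store; a scalar copy is used here for readability.
    for (int i = 0; i < A_CHUNK; ++i) s[k_in + i] = A[k_in + i];
  }
  __syncthreads();  // all of A must be visible before the GEMM loop starts
}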
@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < 32 * 1024)
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64/160 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[1024 * 32];
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
for (uint32_t k = 0; k < min(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min(K * N, 32 * 1024)) break;
if (k_in >= min(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
}
@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t kBase = 0;
// find biggest k size that fits in LDS
uint32_t kFit = (32 * 1024) / N;
uint32_t kFit = (max_lds_len) / N;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
kFit = (kFit % TUC == 0)
@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
return out_c;
}
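The WVSPLITK macro above now expresses its kernel choice in terms of the per-device LDS budget instead of a hard-coded 32 K elements: the small kernel runs when all of A fits in LDS, the medium kernel when it only marginally exceeds it (the 1.2x slack), and the big kernel otherwise. A host-side restatement of that policy follows (the enum and helper are illustrative, not vLLM APIs; the final fallback to the big variant is implied by the surrounding code rather than shown in this excerpt).

// Illustrative restatement of the WVSPLITK selection logic; names are made up.
enum class SkinnyGemmVariant { Small, Medium, Big };

SkinnyGemmVariant pick_wvsplitk_variant(
    long long K_in, long long N_in, long long M_in, int ytile_small,
    int max_lds_len /* elements = LDS bytes / sizeof(scalar_t) */) {
  if (K_in * N_in <= max_lds_len && M_in % ytile_small == 0)
    return SkinnyGemmVariant::Small;   // wvSplitK_hf_sml_: all of A[] cached in LDS
  if (K_in * N_in <= max_lds_len * 1.2)
    return SkinnyGemmVariant::Medium;  // wvSplitK_hf_: A[] marginally exceeds LDS
  return SkinnyGemmVariant::Big;       // wvSplitK_hf_big_: A[] streamed in kFit-sized chunks
}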
#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) {
constexpr int max_lds_len = LDS_SIZE;
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
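Note the asymmetry with the half/bf16 kernels earlier in the file: the fp8 path sets max_lds_len to the full LDS_SIZE because each element is one byte, while the 2-byte paths use LDS_SIZE / 2 (and the host dispatchers divide get_lds_size() the same way). A compile-time spot check of the resulting element budgets (illustrative plain C++; unsigned short and unsigned char stand in for the 2-byte and 1-byte element types):

// Illustrative only: per-data-type LDS element budgets implied by LDS_SIZE.
template <typename T>
constexpr int max_lds_elems(int lds_bytes) {
  return lds_bytes / static_cast<int>(sizeof(T));
}

static_assert(max_lds_elems<unsigned short>(64 * 1024) == 32 * 1024,
              "half/bf16 budget on gfx90a/gfx942");
static_assert(max_lds_elems<unsigned char>(64 * 1024) == 64 * 1024,
              "fp8 budget on gfx942");
static_assert(max_lds_elems<unsigned short>(160 * 1024) == 80 * 1024,
              "half/bf16 budget on gfx950");
static_assert(max_lds_elems<unsigned char>(160 * 1024) == 160 * 1024,
              "fp8 budget on gfx950");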
@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const fp8_t* __restrict__ A, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE;
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
__shared__ fp8_t s[1024 * 64];
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
}
__syncthreads();
@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < 64 * 1024)
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b,
@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \

View File

@ -8,7 +8,7 @@ from vllm.platforms import current_platform
# Using the default value (240.0) from PyTorch will cause accuracy
# issues on dynamic quantization models. Use 224.0 for ROCm instead.
ROCM_FP8_MAX = 224.0
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

View File

@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())

View File

@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
from vllm.platforms.rocm import on_gfx9
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)

View File

@ -105,9 +105,15 @@ def on_gfx1x() -> bool:
@cache
def on_mi250_mi300() -> bool:
def on_mi3xx() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
@cache
def on_gfx9() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@cache