[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010)
Signed-off-by: charlifu <charlifu@amd.com>
parent f2c3f66d59
commit 306d60401d

@@ -13,14 +13,34 @@
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"
 
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif
 
-#if defined(__HIPCC__) && defined(__gfx942__)
-  #define __HIP__MI300__
+#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__MI3XX__
 #endif
 
+#if defined(__gfx950__)
+  #define LDS_SIZE 160 * 1024
+#else
+  #define LDS_SIZE 64 * 1024
+#endif
+
+int get_lds_size() {
+  static bool is_cached = false;
+  static int result;
+  if (is_cached == false) {
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    std::string device_arch = dprops->gcnArchName;
+    size_t substring = device_arch.find("gfx95");
+    result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
+    is_cached = true;
+  }
+  return result;
+}
+
 #if defined(NDEBUG)
   #undef NDEBUG
   #include <assert.h>
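
gfx950 raises the usable LDS per CU from 64 KB to 160 KB. The patch tracks this in two places: a compile-time `LDS_SIZE` macro that sizes the `__shared__` buffers in device code, and a host-side `get_lds_size()` that caches the same decision from the device's `gcnArchName` string. A minimal host-only sketch of that arch check (the helper name is ours for illustration, not patch code):

```cpp
#include <cassert>
#include <string>

// Mirrors the substring test in get_lds_size(): any "gfx95*" part reports
// 160 KB of LDS per CU, every other supported GFX9 part reports 64 KB.
static int lds_size_for_arch(const std::string& gcn_arch_name) {
  return gcn_arch_name.find("gfx95") == std::string::npos ? 64 * 1024
                                                          : 160 * 1024;
}

int main() {
  assert(lds_size_for_arch("gfx942:sramecc+:xnack-") == 64 * 1024);
  assert(lds_size_for_arch("gfx950") == 160 * 1024);
}
```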
@@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
     V0 += (s.x + s.y); \
   }
 
-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
@@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
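
Two things change per kernel: `max_lds_len` converts the byte-denominated `LDS_SIZE` into an element count (`/ 2` here because `scalar_t` is a 2-byte half/bfloat16 type; the fp8 kernels further down use `LDS_SIZE` unscaled since `fp8_t` is 1 byte), and the bf16 MFMA path, previously gated on MI300 only, now also covers gfx950 through `__HIP__MI3XX__`. A compile-checkable sketch of the size arithmetic, using plain integer stand-ins for the HIP scalar types:

```cpp
// Element capacity of an LDS byte budget for a given element type.
template <typename T>
constexpr int lds_element_capacity(int lds_size_bytes) {
  return lds_size_bytes / static_cast<int>(sizeof(T));
}

// 2-byte stand-in for __hip_bfloat16 / _Float16:
static_assert(lds_element_capacity<unsigned short>(64 * 1024) == 32 * 1024,
              "matches the old hard-coded s[1024 * 32] scalar_t buffer");
static_assert(lds_element_capacity<unsigned short>(160 * 1024) == 80 * 1024,
              "gfx950 tier for the half/bf16 kernels");
// 1-byte stand-in for fp8_t (no /2, matching max_lds_len = LDS_SIZE there):
static_assert(lds_element_capacity<unsigned char>(64 * 1024) == 64 * 1024,
              "matches the old hard-coded s[1024 * 64] fp8_t buffer");
```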
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   //       then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Fetch the activation matrix to LDS
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //  - Then the WG will move to another 8 K elements
   //  TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
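
The staging loop above is a cooperative copy: each lane moves one `A_CHUNK`-element vector, the whole workgroup advances by `THRDS * WvPrGrp * A_CHUNK`, and the only change is clamping against `max_lds_len` instead of the hard-coded 32 K elements. A host-side model of the index pattern (launch shape values are illustrative), checking that every element below the clamp is copied exactly once:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const int THRDS = 64, WvPrGrp = 8, A_CHUNK = 8;  // example launch shape
  const int K = 1024, N = 2, max_lds_len = 32 * 1024;
  const int limit = std::min(K * N, max_lds_len);
  std::vector<int> copies(limit, 0);

  // for (k = 0; k < limit; k += THRDS * WvPrGrp * A_CHUNK)
  //   k_in = k + (wave * THRDS + lane) * A_CHUNK, one bigType per lane.
  for (int k = 0; k < limit; k += THRDS * WvPrGrp * A_CHUNK)
    for (int wave = 0; wave < WvPrGrp; ++wave)
      for (int lane = 0; lane < THRDS; ++lane) {
        const int k_in = k + (wave * THRDS + lane) * A_CHUNK;
        if (k_in >= limit) break;
        for (int c = 0; c < A_CHUNK; ++c) ++copies[k_in + c];
      }

  // Disjoint, complete coverage of the staged region.
  assert(std::count(copies.begin(), copies.end(), 1) == limit);
}
```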
@@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
@@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
@@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                  const scalar_t* __restrict__ A, scalar_t* C,
                  const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   // TODO: When activation matrix is larger than 64 KB
   //       then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //  - Then the WG will move to another 8 K elements
   //  TODO: Logic below will only work when K is multiple of 8
   //----------------------------------------------------
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       // Fetch A activation matrix in interleaved fashion from LDS or memory
 
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 32 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
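
The per-element guard above decides the source of each activation load: the first `max_lds_len` elements of the flattened `A` were staged into LDS, anything beyond streams straight from global memory. A tiny host-side model of that predicate (not patch code; the numbers pick the gfx950 half-precision tier):

```cpp
#include <cassert>

// Element (k_, n) is served from the LDS copy iff its flattened index fits.
static bool served_from_lds(int k_, int K, int n, int max_lds_len) {
  return k_ + K * n < max_lds_len;
}

int main() {
  const int max_lds_len = 80 * 1024;  // 160 KB LDS / 2-byte scalar_t
  assert(served_from_lds(0, 40 * 1024, 1, max_lds_len));    // hits LDS
  assert(!served_from_lds(16, 40 * 1024, 2, max_lds_len));  // falls to HBM
}
```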
@@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
@@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
@@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                      const scalar_t* __restrict__ A, scalar_t* C,
                      const int _WvPrGrp, const int CuCount) {
-#if defined(__HIP__MI300__)
+  constexpr int max_lds_len = LDS_SIZE / 2;
+#if defined(__HIP__MI3XX__)
   constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
 #else
   constexpr bool use_mfma = false;
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   };
 
   //----------------------------------------------------
-  // Reserving 64 KB of LDS to have 1 WG / CU
+  // Reserving 64/160 KB of LDS to have 1 WG / CU
   // Goal is to bring the activation matrix A to the LDS
   // and use it across the lifetime of the work group
   // TODO: When activation matrix is larger than 64 KB
   //       then this is not goint to work!
   //----------------------------------------------------
-  __shared__ scalar_t s[1024 * 32];
+  __shared__ scalar_t s[max_lds_len];
 
   //----------------------------------------------------
   // Computation of columns that need to be committed to memory!
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   //----------------------------------------------------
 #define PCML
 #ifndef PCML
-  for (uint32_t k = 0; k < min(K * N, 32 * 1024);
+  for (uint32_t k = 0; k < min(K * N, max_lds_len);
        k += THRDS * WvPrGrp * A_CHUNK) {
     uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
 
-    if (k_in >= min(K * N, 32 * 1024)) break;
+    if (k_in >= min(K * N, max_lds_len)) break;
 
     *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
   }
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
 #define TUC (THRDS * UNRL * A_CHUNK)
   uint32_t kBase = 0;
   // find biggest k size that fits in LDS
-  uint32_t kFit = (32 * 1024) / N;
+  uint32_t kFit = (max_lds_len) / N;
   // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
   // of TUC
   kFit = (kFit % TUC == 0)
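
Here `kFit` is the largest per-column K extent whose `N` columns still fit in the staged region, subsequently snapped to a multiple of `TUC` (the element count one unrolled pass of the workgroup consumes). The diff view truncates before the snap's else-branch, so the sketch below shows the round-up variant named in the commented-out line above, as an assumption rather than a quote:

```cpp
#include <cassert>

// Round v up to the next multiple of q (assumed snapping direction).
static unsigned round_up(unsigned v, unsigned q) {
  return (v % q == 0) ? v : v - v % q + q;
}

int main() {
  const unsigned THRDS = 64, UNRL = 2, A_CHUNK = 8;  // example parameters
  const unsigned TUC = THRDS * UNRL * A_CHUNK;
  const unsigned max_lds_len = 80 * 1024;  // gfx950 tier, 2-byte scalars
  const unsigned N = 3;
  unsigned kFit = max_lds_len / N;         // 27306 elements per column
  kFit = round_up(kFit, TUC);
  assert(kFit % TUC == 0 && kFit >= max_lds_len / N);
}
```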
@@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     }
   }
 }
-#else   // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
 template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
           int UNRL, int N>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
@@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size() / 2;
 
 #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                  _N)                                                         \
   {                                                                          \
     dim3 block(64, _WvPrGrp);                                                \
-    if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) {               \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {             \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);             \
       wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N>         \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,   \
                                        CuCount);                             \
-    } else if (K_in * N_in <= 32 * 1024 * 1.2) {                             \
+    } else if (K_in * N_in <= max_lds_len * 1.2) {                           \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp);             \
       wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N>             \
           <<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp,   \
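
With `max_lds_len` now computed at runtime, the `WVSPLITK` macro's tier selection scales with the device: `wvSplitK_hf_sml_` when the whole activation fits LDS (and M tiles evenly), `wvSplitK_hf_` when it exceeds LDS by at most 20%, and `wvSplitK_hf_big_` otherwise. A host-side model of the selection (enum and sizes illustrative), showing a shape that was a marginal case on 64 KB parts becoming LDS-resident on gfx950:

```cpp
#include <cassert>

enum class Tier { sml, hf, big };

// Mirrors the WVSPLITK branch order: fits LDS -> sml, within 1.2x -> hf.
static Tier pick_tier(long K, long N, long M, long YTILEs, long max_lds_len) {
  if (K * N <= max_lds_len && M % YTILEs == 0) return Tier::sml;
  if (K * N <= max_lds_len * 1.2) return Tier::hf;
  return Tier::big;
}

int main() {
  const long lds_mi300 = 32 * 1024, lds_gfx950 = 80 * 1024;  // fp16 elements
  const long K = 36 * 1024, N = 1, M = 256, YTILEs = 8;
  assert(pick_tier(K, N, M, YTILEs, lds_mi300) == Tier::hf);    // marginal
  assert(pick_tier(K, N, M, YTILEs, lds_gfx950) == Tier::sml);  // fits LDS
}
```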
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }
 
-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
@@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300__)  // TODO: Add NAVI support
+#if defined(__HIP__MI3XX__)  // TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                  const fp8_t* __restrict__ A, scalar_t* C,
                  const float* __restrict__ s_A, const float* __restrict__ s_B,
                  const int _WvPrGrp, const int CuCount) {
+  constexpr int max_lds_len = LDS_SIZE;
   using scalar8 =
       __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
   using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     scalar8 h8;
   };
 
-  __shared__ fp8_t s[1024 * 64];
+  __shared__ fp8_t s[max_lds_len];
 
   for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
-       k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
+       k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
     *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
   }
   __syncthreads();
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       uint32_t k_ = k + threadIdx.x * A_CHUNK;
       if (k_ >= K) break;
       for (int n = 0; n < N; n++) {
-        if (k_ + K * n < 64 * 1024)
+        if (k_ + K * n < max_lds_len)
           bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
         else
           bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__MI300__) TODO: Add NAVI support
+#else   // !defined(__HIP__MI3XX__) TODO: Add NAVI support
 template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
           int A_CHUNK, int UNRL, int N>
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                               const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI3XX__) TODO: Add NAVI support
 
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b,
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
   dim3 grid(CuCount);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const int max_lds_len = get_lds_size();
 
 #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
                   _N)                                                          \
   {                                                                            \
     dim3 block(64, _WvPrGrp);                                                  \
-    if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) {                 \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {               \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp);               \
       wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N>  \
           <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
 
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
-ROCM_FP8_MAX = 224.0
+ROCM_FP8FNUZ_MAX = 224.0
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
@@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
         else torch.finfo(quant_dtype)
-    qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
+    qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
+        and current_platform.is_fp8_fnuz() \
         else qtype_traits.max
-    qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
+    qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
+        and current_platform.is_fp8_fnuz() \
         else qtype_traits.min
     qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
@@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     -> tuple[torch.tensor, torch.tensor]:
 
     fp8_traits = torch.finfo(FP8_DTYPE)
-    fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
+    fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
+        and current_platform.is_fp8_fnuz() \
         else fp8_traits.max
-    fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
+    fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
+        and current_platform.is_fp8_fnuz() \
         else fp8_traits.min
     fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
@@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: list) -> torch.Tensor:
-    from vllm.platforms.rocm import on_mi250_mi300
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
+    from vllm.platforms.rocm import on_mi3xx
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
     ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
@@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
 def rocm_unquantized_gemm(x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
-    from vllm.platforms.rocm import on_mi250_mi300
+    from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
-    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
+    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
                   and k % 8 == 0 and bias is None)
 
@@ -105,9 +105,15 @@ def on_gfx1x() -> bool:
 
 
 @cache
-def on_mi250_mi300() -> bool:
+def on_mi3xx() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
-    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
+    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
+
+
+@cache
+def on_gfx9() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
 
 
 @cache