From 306d60401dbd066f64298e02ca73d4f2075d7bf6 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Sat, 31 May 2025 09:40:05 -0500 Subject: [PATCH] [ROCm][Kernel] Add gfx950 support for skinny gemms (#18010) Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 113 +++++++++++------- tests/kernels/quant_utils.py | 14 ++- .../layers/quantization/utils/w8a8_utils.py | 4 +- vllm/model_executor/layers/utils.py | 4 +- vllm/platforms/rocm.py | 10 +- 5 files changed, 91 insertions(+), 54 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index b3717892db78..e31aa0162628 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -13,14 +13,34 @@ #include "dispatch_utils.h" #include "quantization/fp8/common.cuh" -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ +#if defined(__HIPCC__) && \ + (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__GFX9__ #endif -#if defined(__HIPCC__) && defined(__gfx942__) - #define __HIP__MI300__ +#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__MI3XX__ #endif +#if defined(__gfx950__) + #define LDS_SIZE 160 * 1024 +#else + #define LDS_SIZE 64 * 1024 +#endif + +int get_lds_size() { + static bool is_cached = false; + static int result; + if (is_cached == false) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + size_t substring = device_arch.find("gfx95"); + result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024); + is_cached = true; + } + return result; +} + #if defined(NDEBUG) #undef NDEBUG #include @@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, V0 += (s.x + s.y); \ } -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template @@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; @@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) }; //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU + // Reserving 64/160 KB of LDS to have 1 WG / CU // Goal is to bring the activation matrix A to the LDS // and use it across the lifetime of the work group // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! 
//---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Fetch the activation matrix to LDS @@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, @@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity template @@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; @@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! 
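// A minimal sketch, assuming scalar_t is a 2-byte type (fp16/bf16) as the
// LDS_SIZE / 2 expression above implies, of how the per-arch LDS budget
// translates into the element count of the shared buffer s[max_lds_len]
// (the constant names below are illustrative only):
constexpr int kLdsBytesGfx90aGfx942 = 64 * 1024;              // gfx90a / gfx942
constexpr int kLdsBytesGfx950 = 160 * 1024;                   // gfx950
constexpr int kElemsGfx90aGfx942 = kLdsBytesGfx90aGfx942 / 2; // 32768 elements
constexpr int kElemsGfx950 = kLdsBytesGfx950 / 2;             // 81920 elements
static_assert(kElemsGfx90aGfx942 == 32 * 1024,
              "matches the previous hard-coded s[1024 * 32] bound");
static_assert(kElemsGfx950 == 80 * 1024,
              "gfx950 keeps 2.5x more activation elements resident in LDS");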
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory for (int n = 0; n < N; n++) { - if (k_ + K * n < 32 * 1024) + if (k_ + K * n < max_lds_len) bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); @@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, @@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity template @@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; @@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) }; //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU + // Reserving 64/160 KB of LDS to have 1 WG / CU // Goal is to bring the activation matrix A to the LDS // and use it across the lifetime of the work group // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #define PCML #ifndef PCML - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #define TUC (THRDS * UNRL * A_CHUNK) uint32_t kBase = 0; // find biggest k size that fits in LDS - uint32_t kFit = (32 * 1024) / N; + uint32_t kFit = (max_lds_len) / N; // kFit = (kFit%TWC==0) ? 
kFit : (kFit-kFit%TWC+TWC); //round up to multiple // of TUC kFit = (kFit % TUC == 0) @@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, @@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; @@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size() / 2; #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) \ { \ dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ - } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + } else if (K_in * N_in <= max_lds_len * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitK_hf_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ @@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, return out_c; } -#if defined(__HIP__MI300__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; using scalar8 = __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) scalar8 h8; }; - __shared__ fp8_t s[1024 * 64]; + __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; - k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); @@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, @@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support -#if defined(__HIP__MI300__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) const fp8_t* __restrict__ A, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, 
const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; using scalar8 = __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) scalar8 h8; }; - __shared__ fp8_t s[1024 * 64]; + __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; - k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); @@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; for (int n = 0; n < N; n++) { - if (k_ + K * n < 64 * 1024) + if (k_ + K * n < max_lds_len) bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); @@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support template __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, @@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, @@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, dim3 grid(CuCount); const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size(); #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) \ { \ dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitKQ_hf_sml_ \ <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 764924f26783..892309a017e4 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -8,7 +8,7 @@ from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. 
-ROCM_FP8_MAX = 224.0 +ROCM_FP8FNUZ_MAX = 224.0 FP8_DTYPE = current_platform.fp8_dtype() @@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor, qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else qtype_traits.max - qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else qtype_traits.min qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) @@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else fp8_traits.max - fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else fp8_traits.min fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4b041cff2ecc..eed8998fe3da 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: - from vllm.platforms.rocm import on_mi250_mi300 - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( + from vllm.platforms.rocm import on_mi3xx + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count()) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 18783d0d7785..001e6aaf0cc7 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, def rocm_unquantized_gemm(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - from vllm.platforms.rocm import on_mi250_mi300 + from vllm.platforms.rocm import on_gfx9 k = weight.shape[1] - use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \ + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 06ee8614d1f0..ef1c632a5398 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -105,9 +105,15 @@ def on_gfx1x() -> bool: @cache -def on_mi250_mi300() -> bool: +def on_mi3xx() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName - return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) + return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"]) + + +@cache +def on_gfx9() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", 
"gfx950"]) @cache