[ROCM] Add gfx950 to the custom attention archs (#16034)

Signed-off-by: jpvillam <Juan.Villamizar@amd.com>
Signed-off-by: seungrokjung <seungrok.jung@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: seungrokjung <seungrok.jung@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

commit 811a6c0972
parent 9b1769dd9a
Author: Juan Villamizar
Date:   2025-05-01 13:18:28 -05:00 (committed by GitHub)
2 changed files with 12 additions and 8 deletions


@@ -25,8 +25,9 @@
 #include "../attention/dtype_fp8.cuh"
 #include "../quantization/fp8/amd/quant_utils.cuh"
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif
 #if defined(NDEBUG)
@@ -42,7 +43,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) // TODO: Add NAVI support
 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
@@ -1479,7 +1480,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   }
 }
-#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
 // clang-format off
 template <typename scalar_t, typename cache_t,
@@ -1552,7 +1553,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 }
 // clang-format on
-#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
 #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
   paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
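
For reference, the `__HIP__GFX9__` guard above only decides whether the custom kernels are compiled for a given target; at runtime the same architecture family can be recognized through the gcnArchName string that PyTorch's ROCm build exposes, which is also what the Python change below relies on. A minimal sketch, assuming a ROCm build of PyTorch; the helper name and arch list here are illustrative and simply mirror the archs named in this commit.

import torch

# Archs for which the custom paged-attention kernels are built after this
# change (mirrors the __HIP__GFX9__ guard: gfx90a, gfx942, gfx950).
GFX9_ARCHS = ["gfx90a", "gfx942", "gfx950"]

def device_is_gfx9(device_index: int = 0) -> bool:
    # On ROCm builds gcnArchName looks like "gfx942:sramecc+:xnack-",
    # so a substring match against the base arch name is enough.
    arch = torch.cuda.get_device_properties(device_index).gcnArchName
    return any(gfx in arch for gfx in GFX9_ARCHS)

if __name__ == "__main__" and torch.cuda.is_available():
    print("gfx9 device:", device_is_gfx9())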


@@ -106,11 +106,14 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     max_seq_len: int,
                                     sliding_window: int) -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
     # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
-                                  or sliding_window == (-1, -1))
+    return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
+                         or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
@@ -337,7 +340,7 @@ class RocmPlatform(Platform):
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ['gfx94']
+        supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)

     @classmethod
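
As a usage sketch of the first hunk in this file (not vLLM's actual API surface; the function name below is made up, the environment flag is modeled as a plain parameter, and the trailing conditions of the original return statement are cut off in the hunk, so only the visible checks are reproduced): the updated predicate replaces the MI250/MI300-only helper with an ON_GFX9 test that also accepts gfx950, while keeping the dtype, head-size, block-size, and sliding-window constraints.

import torch

def supports_rocm_custom_paged_attention(qtype: torch.dtype,
                                         head_size: int,
                                         block_size: int,
                                         sliding_window,
                                         use_v1: bool = False) -> bool:
    # Illustrative re-implementation of the visible part of the gate.
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    on_gfx9 = any(arch in gpu_arch for arch in ["gfx90a", "gfx942", "gfx950"])
    return (on_gfx9
            # On V1, the sliding window must be disabled.
            and (not use_v1 or sliding_window in (0, (-1, -1)))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32))

# Example: a gfx950 device with bf16, head_size 128, block_size 16 and no
# sliding window now passes the arch check that previously required
# gfx90a/gfx942.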