Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 01:35:01 +08:00)
[ROCM] Add gfx950 to the custom attention archs (#16034)
Signed-off-by: jpvillam <Juan.Villamizar@amd.com>
Signed-off-by: seungrokjung <seungrok.jung@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: seungrokjung <seungrok.jung@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent 9b1769dd9a
commit 811a6c0972
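Quick context for the diff below: the custom paged-attention kernels and the custom allreduce path are gated on the GPU's reported gcnArchName, and this commit adds gfx950 to those gates. As a minimal sketch (assuming a ROCm build of PyTorch; the getattr fallback and the helper name are mine, not vLLM's), you can inspect the string those gates key on like this:

# Sketch: inspect the architecture string the gates below key on.
# Assumes a ROCm build of PyTorch; gcnArchName is absent on CUDA-only builds,
# hence the getattr fallback. Helper name is illustrative, not vLLM API.
import torch

def reported_gcn_arch(device: int = 0) -> str:
    """Return e.g. "gfx942:sramecc+:xnack-", or "" if unavailable."""
    if not torch.cuda.is_available():
        return ""
    props = torch.cuda.get_device_properties(device)
    return getattr(props, "gcnArchName", "")

if __name__ == "__main__":
    arch = reported_gcn_arch()
    # After this commit the custom attention targets are gfx90a, gfx942, gfx950.
    print(arch, any(a in arch for a in ("gfx90a", "gfx942", "gfx950")))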
@@ -25,8 +25,9 @@
 #include "../attention/dtype_fp8.cuh"
 #include "../quantization/fp8/amd/quant_utils.cuh"
 
-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
-  #define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+    (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
+  #define __HIP__GFX9__
 #endif
 
 #if defined(NDEBUG)
@@ -42,7 +43,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
-#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
 
 #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
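The DIVIDE_ROUND_UP macro kept as context in the hunk above is ordinary ceiling division. A tiny Python illustration of the same arithmetic (the function name is mine):

# Same arithmetic as the DIVIDE_ROUND_UP C macro above: ceiling division.
def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

assert divide_round_up(10, 4) == 3   # 10 / 4 = 2.5 rounds up to 3
assert divide_round_up(8, 4) == 2    # exact multiples are unchanged
assert divide_round_up(1, 16) == 1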
@@ -1479,7 +1480,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   }
 }
 
-#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else  // !defined(__HIP__GFX9__) TODO: Add NAVI support
 
 // clang-format off
 template <typename scalar_t, typename cache_t,
@@ -1552,7 +1553,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 }
 // clang-format on
 
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
 #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
   paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
@@ -106,11 +106,14 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     max_seq_len: int,
                                     sliding_window: int) -> bool:
 
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
+
     # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
-                                  or sliding_window == (-1, -1))
+    return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
+                         or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
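Restated outside the diff, the gate above reduces to the predicate below. This is an illustrative sketch, not vLLM's API: the function name and the explicit gcn_arch_name/use_v1 parameters are mine, and only the conditions visible in this hunk are included.

# Illustrative restatement of the gate in the hunk above (not vLLM's actual API).
# gcn_arch_name is the string torch.cuda.get_device_properties(...).gcnArchName returns.
import torch

GFX9_TARGETS = ("gfx90a", "gfx942", "gfx950")

def supports_custom_paged_attention(gcn_arch_name: str, qtype: torch.dtype,
                                    head_size: int, block_size: int,
                                    sliding_window, use_v1: bool) -> bool:
    on_gfx9 = any(arch in gcn_arch_name for arch in GFX9_TARGETS)
    # On V1 the sliding window must be disabled (0 or (-1, -1)); V0 is unrestricted.
    sliding_ok = (not use_v1) or sliding_window in (0, (-1, -1))
    return (on_gfx9
            and sliding_ok
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32))

# gfx942 (MI300-series) and the newly added gfx950 both pass; gfx1100 does not.
assert supports_custom_paged_attention("gfx942:sramecc+:xnack-", torch.half, 128, 16, 0, True)
assert supports_custom_paged_attention("gfx950", torch.bfloat16, 64, 32, 0, False)
assert not supports_custom_paged_attention("gfx1100", torch.half, 128, 16, 0, False)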
@@ -337,7 +340,7 @@ class RocmPlatform(Platform):
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ['gfx94']
+        supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)
 
     @classmethod
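The allreduce gate above matches on architecture-family prefixes rather than full names: 'gfx94' covers the MI300-series gfx94x names and the new 'gfx95' entry covers gfx950. A small sketch of that matching (the helper name is mine; the prefix list is from the hunk):

# Prefix-family matching as in the allreduce gate above.
# gcnArchName strings typically look like "gfx942:sramecc+:xnack-".
SUPPORTED_ARCH_PREFIXES = ['gfx94', 'gfx95']

def custom_allreduce_supported(gcn_arch: str) -> bool:
    return any(gfx in gcn_arch for gfx in SUPPORTED_ARCH_PREFIXES)

assert custom_allreduce_supported("gfx942:sramecc+:xnack-")      # MI300-series
assert custom_allreduce_supported("gfx950")                      # newly enabled here
assert not custom_allreduce_supported("gfx90a:sramecc+:xnack-")  # MI250 remains excluded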