diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 2c3cae95e7f55..292352649163d 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -25,8 +25,9 @@ #include "../attention/dtype_fp8.cuh" #include "../quantization/fp8/amd/quant_utils.cuh" -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ +#if defined(__HIPCC__) && \ + (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__GFX9__ #endif #if defined(NDEBUG) @@ -42,7 +43,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 @@ -1479,7 +1480,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support // clang-format off template bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) + # rocm custom page attention not support on gfx1* # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. - return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) + return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) @@ -337,7 +340,7 @@ class RocmPlatform(Platform): def use_custom_allreduce(cls) -> bool: # We only enable custom allreduce for MI300 series gcn_arch = torch.cuda.get_device_properties(0).gcnArchName - supported_archs = ['gfx94'] + supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) @classmethod