[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)

Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
Authored by Hosang on 2025-05-21 11:35:00 -04:00; committed by GitHub
parent 2b16104557
commit dd5fa7e04f
6 changed files with 1769 additions and 28 deletions


@@ -84,7 +84,10 @@ def main(
     if version == "v2":
         if current_platform.is_rocm():
             global PARTITION_SIZE
-            PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
+            if not args.custom_paged_attn and not current_platform.is_navi():
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
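The partition count above is a plain ceiling division: each sequence is cut
into fixed-size chunks that are reduced in a second pass, so the temporary
buffers grow with max_seq_len / PARTITION_SIZE. A minimal standalone sketch
of the arithmetic (the 256 value is taken from the chunked_prefill_paged_decode
hunk further down, not from this file):

    # Ceiling division: how many fixed-size partitions cover max_seq_len.
    PARTITION_SIZE_ROCM = 256  # value from the chunked-prefill hunk below

    def num_partitions(max_seq_len: int, partition_size: int) -> int:
        return (max_seq_len + partition_size - 1) // partition_size

    assert num_partitions(4096, PARTITION_SIZE_ROCM) == 16
    assert num_partitions(4097, PARTITION_SIZE_ROCM) == 17  # partial chunk counts
    assert num_partitions(1, PARTITION_SIZE_ROCM) == 1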
@@ -159,6 +162,7 @@ def main(
         scale,
         block_tables,
         seq_lens,
+        None,
         block_size,
         max_seq_len,
         alibi_slopes,

File diff suppressed because it is too large.


@@ -148,6 +148,11 @@ def test_paged_attention(
             or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
+
+    if (version == "rocm" and current_platform.is_navi()
+            and (kv_cache_dtype == "fp8" or head_size != 128
+                 or block_size != 16 or use_alibi)):
+        pytest.skip()
 
     global PARTITION_SIZE
 
     current_platform.seed_everything(seed)
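The added guard narrows the Navi run of the "rocm" version to the single
configuration the new kernel path supports. A sketch of the same predicate
as a plain function (navi_rocm_supported is a hypothetical name, not in the
diff):

    def navi_rocm_supported(kv_cache_dtype: str, head_size: int,
                            block_size: int, use_alibi: bool) -> bool:
        # Mirrors the skip above: on Navi the custom kernel is only
        # exercised without fp8 caches, with head_size 128,
        # block_size 16, and no ALiBi slopes.
        return (kv_cache_dtype != "fp8" and head_size == 128
                and block_size == 16 and not use_alibi)

    assert navi_rocm_supported("auto", 128, 16, False)
    assert not navi_rocm_supported("fp8", 128, 16, False)
    assert not navi_rocm_supported("auto", 64, 16, False)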
@@ -275,6 +280,7 @@ def test_paged_attention(
         scale,
         block_tables,
         seq_lens,
+        None,
         block_size,
         max_seq_len,
         alibi_slopes,
@@ -286,7 +292,7 @@ def test_paged_attention(
         opcheck(torch.ops._rocm_C.paged_attention,
                 (output, exp_sums, max_logits, tmp_output, query,
                  key_cache, value_cache, num_kv_heads, scale, block_tables,
-                 seq_lens, block_size, max_seq_len, alibi_slopes,
+                 seq_lens, None, block_size, max_seq_len, alibi_slopes,
                  kv_cache_dtype, k_scale, v_scale),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))
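Both call sites now pass an extra None positional argument after seq_lens;
the parameter it fills is declared in the suppressed kernel diff above, so
only the placeholder is visible here. The opcheck helper in this test is
vLLM's thin wrapper around torch.library.opcheck, which exercises a
registered custom op's schema and fake-tensor kernels against its eager
implementation. A self-contained sketch on a toy op (demo::toy_scale is
hypothetical, defined only for illustration; requires PyTorch >= 2.4):

    import torch
    from torch.library import custom_op, opcheck

    @custom_op("demo::toy_scale", mutates_args=())
    def toy_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale

    @toy_scale.register_fake
    def _(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Shape/dtype-only kernel used for fake-tensor propagation.
        return torch.empty_like(x)

    # Runs schema and fake-tensor checks; raises on any mismatch.
    opcheck(toy_scale, (torch.randn(4), 2.0))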


@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             gqa_ratio = num_heads // self.num_kv_heads
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
-                decode_meta.max_decode_seq_len, self.sliding_window)
+                decode_meta.max_decode_seq_len, self.sliding_window,
+                self.kv_cache_dtype, self.alibi_slopes)
             if use_custom:
                 max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                                != AttentionType.ENCODER_DECODER else


@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
     use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
                                                  block_size,
                                                  num_queries_per_kv,
-                                                 max_seq_len, sliding_window)
+                                                 max_seq_len, sliding_window,
+                                                 kv_cache_dtype, alibi_slopes)
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
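When the custom path is taken, the 256-element partition size determines the
shapes of the split-KV scratch buffers. A sketch of the allocation, with the
tmp_output layout taken from the benchmark hunk above; the exp_sums and
max_logits layouts are assumptions (they are not shown in this diff):

    import torch

    num_seqs, num_query_heads, head_size = 8, 32, 128
    max_seq_len, _PARTITION_SIZE_ROCM = 4096, 256

    max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                          _PARTITION_SIZE_ROCM)

    # Per-partition partial outputs, reduced in a second pass.
    tmp_output = torch.empty(num_seqs, num_query_heads, max_num_partitions,
                             head_size, dtype=torch.half)
    # Assumed softmax bookkeeping per (seq, head, partition).
    exp_sums = torch.empty(num_seqs, num_query_heads, max_num_partitions,
                           dtype=torch.float32)
    max_logits = torch.empty_like(exp_sums)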


@@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool:
 @cache
-def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
-                                    block_size: int, gqa_ratio: int,
-                                    max_seq_len: int,
-                                    sliding_window: int) -> bool:
+def use_rocm_custom_paged_attention(
+        qtype: torch.dtype,
+        head_size: int,
+        block_size: int,
+        gqa_ratio: int,
+        max_seq_len: int,
+        sliding_window: int,
+        kv_cache_dtype: str,
+        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
+    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
 
-    # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
-                         or sliding_window == (-1, -1))
-            and (qtype == torch.half or qtype == torch.bfloat16)
-            and (head_size == 64 or head_size == 128)
-            and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
-            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
-            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                     and envs.VLLM_ROCM_USE_AITER))
+    if ON_GFX9:
+        return ((not envs.VLLM_USE_V1 or sliding_window == 0
+                 or sliding_window == (-1, -1))
+                and (qtype == torch.half or qtype == torch.bfloat16)
+                and (head_size == 64 or head_size == 128)
+                and (block_size == 16 or block_size == 32)
+                and (gqa_ratio >= 1 and gqa_ratio <= 16)
+                and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+                         and envs.VLLM_ROCM_USE_AITER))
+    else:
+        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
+                                    or sliding_window == (-1, -1))
+                and (qtype == torch.half or qtype == torch.bfloat16)
+                and head_size == 128 and block_size == 16
+                and (gqa_ratio >= 3 and gqa_ratio <= 16)
+                and max_seq_len <= 32768 and alibi_slopes is None
+                and kv_cache_dtype == "auto"
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
 
 
 class RocmPlatform(Platform):
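Read as a whole, the gate now has two arch-specific branches: GFX9
(MI200/MI300-class) keeps the previous constraints, while GFX11/GFX12
(Radeon/Navi) additionally requires head_size 128, block_size 16,
gqa_ratio >= 3, no ALiBi slopes, and an unquantized ("auto") KV cache.
A hypothetical call satisfying the Navi branch (the import path is an
assumption based on the RocmPlatform hunk below; evaluating it needs a
ROCm GPU at runtime):

    import torch
    from vllm.platforms.rocm import use_rocm_custom_paged_attention

    use_custom = use_rocm_custom_paged_attention(
        qtype=torch.bfloat16,   # half or bfloat16 only
        head_size=128,          # Navi branch: exactly 128
        block_size=16,          # Navi branch: exactly 16
        gqa_ratio=8,            # Navi branch: 3 <= gqa_ratio <= 16
        max_seq_len=8192,       # both branches: <= 32768
        sliding_window=0,       # V1 needs sliding window disabled
        kv_cache_dtype="auto",  # Navi branch rejects quantized caches
        alibi_slopes=None)      # Navi branch rejects ALiBi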
@@ -362,3 +378,7 @@ class RocmPlatform(Platform):
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(
             device_id).multi_processor_count
+
+    @classmethod
+    def is_navi(cls) -> bool:
+        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
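For reference, gcnArchName is the ROCm device property carrying the GFX
architecture string, e.g. "gfx942" on MI300-class CDNA parts and "gfx1100"
on RX 7900-class RDNA parts, so the 'gfx1' prefix test above covers
gfx10/11/12. A standalone probe (requires a ROCm build of PyTorch and a
visible GPU):

    import torch

    arch = torch.cuda.get_device_properties(0).gcnArchName
    is_navi = 'gfx1' in arch  # RDNA generations report gfx10xx/11xx/12xx
    on_gfx9 = any(a in arch for a in ("gfx90a", "gfx942", "gfx950"))
    print(f"{arch=} {is_navi=} {on_gfx9=}")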