mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 21:47:34 +08:00
[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)
Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
This commit is contained in:
parent
2b16104557
commit
dd5fa7e04f
@ -84,7 +84,10 @@ def main(
|
||||
if version == "v2":
|
||||
if current_platform.is_rocm():
|
||||
global PARTITION_SIZE
|
||||
PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
|
||||
if not args.custom_paged_attn and not current_platform.is_navi():
|
||||
PARTITION_SIZE = 1024
|
||||
else:
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
@ -159,6 +162,7 @@ def main(
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -148,6 +148,11 @@ def test_paged_attention(
|
||||
or (version == "rocm" and head_size not in (64, 128))):
|
||||
pytest.skip()
|
||||
|
||||
if (version == "rocm" and current_platform.is_navi()
|
||||
and (kv_cache_dtype == "fp8" or head_size != 128
|
||||
or block_size != 16 or use_alibi)):
|
||||
pytest.skip()
|
||||
|
||||
global PARTITION_SIZE
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
@ -275,6 +280,7 @@ def test_paged_attention(
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
@ -286,7 +292,7 @@ def test_paged_attention(
|
||||
opcheck(torch.ops._rocm_C.paged_attention,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
seq_lens, None, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
|
||||
@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
gqa_ratio = num_heads // self.num_kv_heads
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len, self.sliding_window)
|
||||
decode_meta.max_decode_seq_len, self.sliding_window,
|
||||
self.kv_cache_dtype, self.alibi_slopes)
|
||||
if use_custom:
|
||||
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
|
||||
!= AttentionType.ENCODER_DECODER else
|
||||
|
||||
@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
|
||||
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len, sliding_window)
|
||||
max_seq_len, sliding_window,
|
||||
kv_cache_dtype, alibi_slopes)
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
|
||||
@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool:
|
||||
|
||||
|
||||
@cache
|
||||
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
block_size: int, gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int) -> bool:
|
||||
def use_rocm_custom_paged_attention(
|
||||
qtype: torch.dtype,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
|
||||
# rocm custom page attention not support on gfx1*
|
||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||
# disabled due to observed numerical discrepancy.
|
||||
return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
if ON_GFX9:
|
||||
return ((not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
|
||||
else:
|
||||
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128 and block_size == 16
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
@ -362,3 +378,7 @@ class RocmPlatform(Platform):
|
||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(
|
||||
device_id).multi_processor_count
|
||||
|
||||
@classmethod
|
||||
def is_navi(cls) -> bool:
|
||||
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user