[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)
Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
commit dd5fa7e04f
parent 2b16104557
@@ -84,7 +84,10 @@ def main(
     if version == "v2":
         if current_platform.is_rocm():
             global PARTITION_SIZE
-            PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
+            if not args.custom_paged_attn and not current_platform.is_navi():
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
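For reference, the v2 benchmark path splits each sequence into fixed-size partitions and reduces across them, so PARTITION_SIZE directly sizes the temporary buffers. A sketch of that arithmetic with illustrative shapes; 256 matches the _PARTITION_SIZE_ROCM constant visible later in this diff:

    import torch

    max_seq_len = 4096
    num_seqs, num_query_heads, head_size = 8, 32, 128

    for partition_size in (1024, 256):  # generic ROCm vs. custom paged attn
        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        tmp_output = torch.empty(
            (num_seqs, num_query_heads, num_partitions, head_size))
        # 1024 -> 4 partitions; 256 -> 16 partitions per sequence.
        print(partition_size, num_partitions, tuple(tmp_output.shape))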
@@ -159,6 +162,7 @@ def main(
                     scale,
                     block_tables,
                     seq_lens,
+                    None,
                     block_size,
                     max_seq_len,
                     alibi_slopes,
File diff suppressed because it is too large
@@ -148,6 +148,11 @@ def test_paged_attention(
             or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
 
+    if (version == "rocm" and current_platform.is_navi()
+            and (kv_cache_dtype == "fp8" or head_size != 128
+                 or block_size != 16 or use_alibi)):
+        pytest.skip()
+
     global PARTITION_SIZE
 
     current_platform.seed_everything(seed)
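Taken together with the platform-gate change below, this new skip encodes the Navi (gfx11/gfx12) support matrix introduced by this commit: no fp8 KV cache, head_size 128 only, block_size 16 only, and no ALiBi. A restatement as a standalone predicate; the helper name is ours, not vllm's:

    def navi_custom_paged_attn_supported(kv_cache_dtype: str, head_size: int,
                                         block_size: int,
                                         use_alibi: bool) -> bool:
        # Mirrors the inverse of the skip condition above.
        return (kv_cache_dtype != "fp8" and head_size == 128
                and block_size == 16 and not use_alibi)

    assert navi_custom_paged_attn_supported("auto", 128, 16, False)
    assert not navi_custom_paged_attn_supported("fp8", 128, 16, False)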
@@ -275,6 +280,7 @@ def test_paged_attention(
             scale,
             block_tables,
             seq_lens,
+            None,
             block_size,
             max_seq_len,
             alibi_slopes,
@@ -286,7 +292,7 @@ def test_paged_attention(
         opcheck(torch.ops._rocm_C.paged_attention,
                 (output, exp_sums, max_logits, tmp_output, query,
                  key_cache, value_cache, num_kv_heads, scale, block_tables,
-                 seq_lens, block_size, max_seq_len, alibi_slopes,
+                 seq_lens, None, block_size, max_seq_len, alibi_slopes,
                  kv_cache_dtype, k_scale, v_scale),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))
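A hedged note on the new argument: torch.ops._rocm_C.paged_attention now takes one extra optional tensor after seq_lens, and every Python call site in this commit passes None for it; what the kernel does with it is only visible in the suppressed file diff above. A wrapper sketch (ours, not vllm's) showing the updated positional layout:

    import torch  # requires a ROCm build with the _rocm_C extension loaded

    def call_rocm_paged_attention(output, exp_sums, max_logits, tmp_output,
                                  query, key_cache, value_cache, num_kv_heads,
                                  scale, block_tables, seq_lens, block_size,
                                  max_seq_len, alibi_slopes, kv_cache_dtype,
                                  k_scale, v_scale):
        # The unnamed None below is the new optional tensor slot.
        return torch.ops._rocm_C.paged_attention(
            output, exp_sums, max_logits, tmp_output, query,
            key_cache, value_cache, num_kv_heads, scale, block_tables,
            seq_lens, None, block_size, max_seq_len, alibi_slopes,
            kv_cache_dtype, k_scale, v_scale)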
@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 gqa_ratio = num_heads // self.num_kv_heads
                 use_custom = use_rocm_custom_paged_attention(
                     decode_query.dtype, head_size, block_size, gqa_ratio,
-                    decode_meta.max_decode_seq_len, self.sliding_window)
+                    decode_meta.max_decode_seq_len, self.sliding_window,
+                    self.kv_cache_dtype, self.alibi_slopes)
                 if use_custom:
                     max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                                    != AttentionType.ENCODER_DECODER else
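Both this call site and the one in chunked_prefill_paged_decode below now forward kv_cache_dtype and alibi_slopes, so the per-architecture capability decision lives in the platform gate instead of in each backend. A minimal dispatch sketch under that assumption; the fast-path and fallback functions are hypothetical stand-ins:

    from vllm.platforms.rocm import use_rocm_custom_paged_attention

    def paged_decode(query, head_size, block_size, gqa_ratio, max_seq_len,
                     sliding_window, kv_cache_dtype, alibi_slopes):
        # The gate sees every feature the custom kernel might reject (e.g.
        # fp8 KV cache or ALiBi on Navi) and vetoes the fast path itself.
        if use_rocm_custom_paged_attention(query.dtype, head_size, block_size,
                                           gqa_ratio, max_seq_len,
                                           sliding_window, kv_cache_dtype,
                                           alibi_slopes):
            return run_custom_paged_attention(query)   # hypothetical fast path
        return run_generic_paged_attention(query)      # hypothetical fallback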
@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
     use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
                                                  block_size,
                                                  num_queries_per_kv,
-                                                 max_seq_len, sliding_window)
+                                                 max_seq_len, sliding_window,
+                                                 kv_cache_dtype, alibi_slopes)
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
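The max_num_partitions expression is the standard integer ceil-division idiom, i.e. ceil(max_seq_len / 256). A quick worked example with illustrative lengths:

    _PARTITION_SIZE_ROCM = 256
    for max_seq_len in (1, 256, 257, 32768):
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
        print(max_seq_len, "->", max_num_partitions)  # 1, 1, 2, 128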
@@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool:
 
 
 @cache
-def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
-                                    block_size: int, gqa_ratio: int,
-                                    max_seq_len: int,
-                                    sliding_window: int) -> bool:
+def use_rocm_custom_paged_attention(
+        qtype: torch.dtype,
+        head_size: int,
+        block_size: int,
+        gqa_ratio: int,
+        max_seq_len: int,
+        sliding_window: int,
+        kv_cache_dtype: str,
+        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
+    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
 
-    # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
-                         or sliding_window == (-1, -1))
-            and (qtype == torch.half or qtype == torch.bfloat16)
-            and (head_size == 64 or head_size == 128)
-            and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
-            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
-            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                     and envs.VLLM_ROCM_USE_AITER))
+    if ON_GFX9:
+        return ((not envs.VLLM_USE_V1 or sliding_window == 0
+                 or sliding_window == (-1, -1))
+                and (qtype == torch.half or qtype == torch.bfloat16)
+                and (head_size == 64 or head_size == 128)
+                and (block_size == 16 or block_size == 32)
+                and (gqa_ratio >= 1 and gqa_ratio <= 16)
+                and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+                         and envs.VLLM_ROCM_USE_AITER))
+
+    else:
+        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
+                                    or sliding_window == (-1, -1))
+                and (qtype == torch.half or qtype == torch.bfloat16)
+                and head_size == 128 and block_size == 16
+                and (gqa_ratio >= 3 and gqa_ratio <= 16)
+                and max_seq_len <= 32768 and alibi_slopes is None
+                and kv_cache_dtype == "auto"
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
 
 
 class RocmPlatform(Platform):
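A usage sketch of the extended gate, assuming it is imported from vllm.platforms.rocm as the RocmPlatform context suggests. On GFX9 (MI200/MI300-class) the acceptance rules are unchanged apart from the new parameters being ignored; on GFX11/GFX12 the same call additionally requires head_size 128, block_size 16, kv_cache_dtype "auto", no ALiBi slopes, and a GQA ratio of at least 3. Values are illustrative:

    import torch
    from vllm.platforms.rocm import use_rocm_custom_paged_attention

    ok = use_rocm_custom_paged_attention(
        qtype=torch.bfloat16,
        head_size=128,
        block_size=16,
        gqa_ratio=8,        # num_query_heads // num_kv_heads
        max_seq_len=8192,
        sliding_window=0,   # disabled, as the V1 path requires
        kv_cache_dtype="auto",
        alibi_slopes=None)
    # Expect True on gfx90a/gfx942/gfx950 and on gfx11*/gfx12*, provided
    # VLLM_ROCM_CUSTOM_PAGED_ATTN is enabled.
    print(ok)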
@@ -362,3 +378,7 @@ class RocmPlatform(Platform):
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(
             device_id).multi_processor_count
+
+    @classmethod
+    def is_navi(cls) -> bool:
+        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
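is_navi() keys off the gcnArchName string that the ROCm runtime reports through torch. Navi (RDNA) parts report gfx10*/gfx11*/gfx12* names, for example gfx1100 on a Radeon RX 7900-class GPU, while CDNA parts such as MI300X report gfx9xx names; the exact string depends on the installed GPU and ROCm stack. A quick probe:

    import torch

    if torch.cuda.is_available():
        arch = torch.cuda.get_device_properties(0).gcnArchName
        print(arch, "-> is_navi:", "gfx1" in arch)
        # e.g. "gfx1100" -> is_navi: True ; "gfx942" -> is_navi: False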