[FEAT][ROCm] Integrate Paged Attention Kernel from AITER (#15001)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Authored by vllmellm on 2025-04-22 17:46:28 +08:00; committed by GitHub.
Commit 0e237f0035 (parent 8f7bace7c3)
5 changed files with 195 additions and 27 deletions
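
The new kernel path is opt-in: it is only taken when both AITER environment flags used below are enabled, and the quantized kernels additionally require a one-byte KV-cache dtype. A minimal sketch of opting in from Python; the model name and engine arguments are illustrative placeholders, and a ROCm build of vLLM with the aiter package is assumed:

# Opt in to the AITER paged-attention path on ROCm. Set the flags before
# vLLM reads its environment configuration.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"             # master switch for AITER ops
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"  # enable the AITER paged attention kernel

from vllm import LLM  # imported after the flags are set

# An fp8 KV cache exercises the quantized AITER kernels added in this commit;
# the model choice here is an arbitrary example.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
outputs = llm.generate(["Paged attention on ROCm:"])
print(outputs[0].outputs[0].text)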

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="5a77249"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base

@@ -2,6 +2,7 @@
"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
@@ -26,6 +27,32 @@ logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
and envs.VLLM_ROCM_USE_AITER
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as a helper by `ROCmFlashAttentionImpl` and
`ROCmFlashAttentionBackend`.
The choice of module depends on whether AITER paged attention is enabled:
- If enabled, `AITERPagedAttention` is used.
- Otherwise, it defaults to the original `PagedAttention`.
"""
if is_rocm_aiter_paged_attn_enabled():
# Import AITERPagedAttention only when the flag is enabled
from vllm.attention.ops.rocm_aiter_paged_attn import (
AITERPagedAttention)
return AITERPagedAttention()
return PagedAttention()
class ROCmFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -56,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
paged_attn = _get_paged_attn_module()
return paged_attn.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
@@ -65,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
paged_attn = _get_paged_attn_module()
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
paged_attn = _get_paged_attn_module()
paged_attn.copy_blocks(kv_caches, src_to_dists)
@dataclass
@@ -496,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
self.paged_attn_module = _get_paged_attn_module()
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
)
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.sdpa_attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")
self.aiter_kv_scales_initialized = False
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
@@ -624,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
assert value is None
paged_attn = self.paged_attn_module
# The AITER paged attention kernel expects per-token KV scale tensors
# when each cache element is one byte (int8/fp8 dtypes), so the scalar
# scales are expanded to that layout here. This is only done on the
# first forward call, and only when the kv cache is not empty.
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
block_size = kv_cache.shape[2] // (self.num_kv_heads *
self.head_size)
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
self.aiter_kv_scales_initialized = True
k_scale.fill_(layer._k_scale.item())
v_scale.fill_(layer._v_scale.item())
layer._k_scale = k_scale
layer._v_scale = v_scale
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
key_cache, value_cache = paged_attn.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if key is not None and value is not None:
@@ -637,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
PagedAttention.write_to_paged_cache(
paged_attn.write_to_paged_cache(
key,
value,
key_cache,
@@ -768,23 +828,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prefix-enabled attention -
# not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:
num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
output[:num_prefill_tokens] = paged_attn.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,

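The central change in the attention backend above is the cached selector: the paged-attention helper is chosen once per process from the two environment flags, and the AITER class is only imported when it is actually needed. A self-contained sketch of that pattern follows; the two classes are stand-ins for the real vLLM ones, while the flag names match the diff.

# Sketch of the cached, env-driven selection used by the ROCm attention backend.
import os
from functools import cache


class PagedAttention:                       # stand-in for the base helper class
    pass


class AITERPagedAttention(PagedAttention):  # stand-in for the AITER-backed subclass
    pass


@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
    return (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")
            and os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1"))


@cache
def _get_paged_attn_module() -> PagedAttention:
    # The real code defers the AITER import to this branch so the dependency
    # is only loaded when the feature is enabled.
    if is_rocm_aiter_paged_attn_enabled():
        return AITERPagedAttention()
    return PagedAttention()


# @cache means the decision is made once: the same instance is reused, and
# changing the environment afterwards has no effect in this process.
assert _get_paged_attn_module() is _get_paged_attn_module()
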
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import aiter as rocm_aiter
import torch
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils import cdiv
FP8_DTYPE = current_platform.fp8_dtype()
class AITERPagedAttention(PagedAttention):
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale,
v_scale)
else:
kv_cache_torch_dtype = (FP8_DTYPE
if "fp8" in kv_cache_dtype else torch.int8)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
rocm_aiter.reshape_and_cache_with_pertoken_quant(
key, value, key_cache, value_cache, k_scale, v_scale,
slot_mapping.flatten(), True)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
return PagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
kv_cache_dtype=kv_cache_dtype,
num_kv_heads=num_kv_heads,
scale=scale,
alibi_slopes=alibi_slopes,
k_scale=k_scale,
v_scale=v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
seq_lens, max_num_blocks_per_seq, k_scale,
v_scale, output)
return output

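The quantized cache path above (`reshape_and_cache_with_pertoken_quant`) consumes the per-token scale tensors that the backend prepares on the first forward call, shaped `(num_kv_heads, num_blocks * block_size)`. A small sketch of that shape arithmetic, with made-up sizes:

# Illustrative sizes only; the layout mirrors the shape arithmetic in the diff.
import torch

num_kv_heads, head_size = 8, 128
num_blocks, block_size = 4, 16

# The backend derives block_size from the cache shape:
# kv_cache.shape == (2, num_blocks, block_size * num_kv_heads * head_size)
kv_cache = torch.zeros((2, num_blocks, block_size * num_kv_heads * head_size),
                       dtype=torch.int8)
derived_block_size = kv_cache.shape[2] // (num_kv_heads * head_size)
assert derived_block_size == block_size

# The scalar per-tensor scale is broadcast into one slot per (kv head, token slot).
scalar_k_scale = torch.tensor(0.02)        # what layer._k_scale holds initially
k_scale = torch.empty((num_kv_heads, num_blocks * block_size), dtype=torch.float32)
k_scale.fill_(scalar_k_scale.item())
print(k_scale.shape)                       # torch.Size([8, 64])
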
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -533,6 +534,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# Whether to use aiter paged attention.
# It is disabled by default.
"VLLM_ROCM_USE_AITER_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
("true", "1")),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)

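For reference, the registered lambda accepts only the strings "true" and "1" (case-insensitive) as truthy, and the attention backend requires both AITER flags to be on. A sketch of that parsing and gating logic; the helper name here is illustrative, not vLLM API:

# How the new flag is parsed, mirroring the lambda registered above.
import os


def _env_flag(name: str) -> bool:
    return os.getenv(name, "False").lower() in ("true", "1")


os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "True"
print(_env_flag("VLLM_ROCM_USE_AITER_PAGED_ATTN"))   # True ("true"/"1", any case)

os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "yes"
print(_env_flag("VLLM_ROCM_USE_AITER_PAGED_ATTN"))   # False ("yes" is not accepted)

# The AITER paged-attention kernel is only selected when both flags are truthy.
use_aiter_paged_attn = (_env_flag("VLLM_ROCM_USE_AITER_PAGED_ATTN")
                        and _env_flag("VLLM_ROCM_USE_AITER"))
print(use_aiter_paged_attn)                          # False here
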
@@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))
class RocmPlatform(Platform):
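
The condition added above gives the AITER path precedence: when both AITER flags are enabled, `use_rocm_custom_paged_attention` returns False, so the ROCm custom paged-attention kernel does not compete with the AITER kernel for the same decode calls. A reduced sketch of that predicate; `other_checks_pass` stands in for the dtype, head-size, block-size, GQA-ratio and sequence-length checks that the real function also performs:

# Reduced form of the gating logic; only the AITER-related term is spelled out.
def custom_paged_attn_allowed(other_checks_pass: bool,
                              use_aiter: bool,
                              use_aiter_paged_attn: bool) -> bool:
    return other_checks_pass and not (use_aiter_paged_attn and use_aiter)


print(custom_paged_attn_allowed(True, use_aiter=True, use_aiter_paged_attn=True))    # False
print(custom_paged_attn_allowed(True, use_aiter=True, use_aiter_paged_attn=False))   # True
print(custom_paged_attn_allowed(True, use_aiter=False, use_aiter_paged_attn=True))   # True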