From 0e237f00357c968a4f7ae25accd533e924baceff Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 22 Apr 2025 17:46:28 +0800
Subject: [PATCH] [FEAT][ROCm] Integrate Paged Attention Kernel from AITER (#15001)

Signed-off-by: vllmellm
Signed-off-by: tjtanaa
Co-authored-by: tjtanaa
---
 docker/Dockerfile.rocm_base                 |   2 +-
 vllm/attention/backends/rocm_flash_attn.py  | 109 +++++++++++++++-----
 vllm/attention/ops/rocm_aiter_paged_attn.py | 101 ++++++++++++++++++
 vllm/envs.py                                |   7 ++
 vllm/platforms/rocm.py                      |   3 +-
 5 files changed, 195 insertions(+), 27 deletions(-)
 create mode 100644 vllm/attention/ops/rocm_aiter_paged_attn.py

diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index 05192eb69b54b..1776b26d445ce 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="5a77249"
+ARG AITER_BRANCH="7e1ed08"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 90a21906b6e63..37b6cadcb98ab 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -2,6 +2,7 @@
 """Attention layer ROCm GPUs."""
 import itertools
 from dataclasses import dataclass
+from functools import cache
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
@@ -26,6 +27,32 @@ logger = init_logger(__name__)

 _PARTITION_SIZE_ROCM = 256


+@cache
+def is_rocm_aiter_paged_attn_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+@cache
+def _get_paged_attn_module() -> PagedAttention:
+    """
+    Returns the appropriate PagedAttention module from `attention/ops`,
+    which is used as a helper
+    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
+
+    The choice of attention module depends on whether
+    AITER paged attention is enabled:
+    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
+    - Otherwise, it defaults to using the original `PagedAttention`.
+    """
+    if is_rocm_aiter_paged_attn_enabled():
+        # Import AITERPagedAttention only when the flag is enabled
+        from vllm.attention.ops.rocm_aiter_paged_attn import (
+            AITERPagedAttention)
+        return AITERPagedAttention()
+    return PagedAttention()
+
+
 class ROCmFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -56,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        paged_attn = _get_paged_attn_module()
+        return paged_attn.get_kv_cache_shape(num_blocks, block_size,
+                                             num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@@ -65,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.copy_blocks(kv_caches, src_to_dists)


 @dataclass
@@ -496,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        self.paged_attn_module = _get_paged_attn_module()
+        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
+        )
+
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")

+        self.aiter_kv_scales_initialized = False
+
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
         tokens, n_kv_heads, head_dim = x.shape
@@ -624,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             assert value is None

+        paged_attn = self.paged_attn_module
+
+        # Reshaping kv tensors is required for the AITER paged attention
+        # kernel because it operates on a different tensor layout when
+        # each element of the KV cache is one byte (int8/fp8 dtypes).
+        # This reshaping is only required on the first forward call,
+        # and only when the kv cache is non-empty.
+        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
+                and not self.aiter_kv_scales_initialized
+                and kv_cache.shape != torch.Size([0])):
+            num_blocks = kv_cache.shape[1]
+            block_size = kv_cache.shape[2] // (self.num_kv_heads *
+                                               self.head_size)
+            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            self.aiter_kv_scales_initialized = True
+            k_scale.fill_(layer._k_scale.item())
+            v_scale.fill_(layer._v_scale.item())
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+
         # Only update KV cache for decoder self-attention
         # and encoder-decoder cross-attention
         if self.attn_type not in [
                 AttentionType.ENCODER, AttentionType.ENCODER_ONLY
         ] and kv_cache.numel() > 0:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
+            key_cache, value_cache = paged_attn.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

             if key is not None and value is not None:
@@ -637,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # cache. If kv_cache is not provided, the new key and value
                 # tensors are not cached. This happens during the initial
                 # memory profiling run.
-                PagedAttention.write_to_paged_cache(
+                paged_attn.write_to_paged_cache(
                     key,
                     value,
                     key_cache,
@@ -768,23 +828,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    # prefix-enabled attention -
                    # not applicable for encoder-only models
                    if self.attn_type != AttentionType.ENCODER_ONLY:
-                        output[:
-                               num_prefill_tokens] = PagedAttention.forward_prefix(
-                                   query,
-                                   key,
-                                   value,
-                                   self.kv_cache_dtype,
-                                   key_cache,
-                                   value_cache,
-                                   prefill_meta.block_tables,
-                                   prefill_meta.query_start_loc,
-                                   prefill_meta.seq_lens_tensor,
-                                   prefill_meta.max_query_len,
-                                   self.alibi_slopes,
-                                   self.sliding_window[0],
-                                   layer._k_scale,
-                                   layer._v_scale,
-                               )
+                        output[:num_prefill_tokens] = paged_attn.forward_prefix(
+                            query,
+                            key,
+                            value,
+                            self.kv_cache_dtype,
+                            key_cache,
+                            value_cache,
+                            prefill_meta.block_tables,
+                            prefill_meta.query_start_loc,
+                            prefill_meta.seq_lens_tensor,
+                            prefill_meta.max_query_len,
+                            self.alibi_slopes,
+                            self.sliding_window[0],
+                            layer._k_scale,
+                            layer._v_scale,
+                        )
         # Skip decode phase for encoder-only models
         if (decode_meta := attn_metadata.decode_metadata) and (
                 self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         layer._v_scale,
                     )
                 else:
-                    output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    output[num_prefill_tokens:] = paged_attn.forward_decode(
                         decode_query,
                         key_cache,
                         value_cache,
diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py
new file mode 100644
index 0000000000000..0f3cf1842c805
--- /dev/null
+++ b/vllm/attention/ops/rocm_aiter_paged_attn.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import aiter as rocm_aiter
+import torch
+
+from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.platforms import current_platform
+from vllm.utils import cdiv
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class AITERPagedAttention(PagedAttention):
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache, slot_mapping,
+                                                kv_cache_dtype, k_scale,
+                                                v_scale)
+        else:
+            kv_cache_torch_dtype = (FP8_DTYPE
+                                    if "fp8" in kv_cache_dtype else torch.int8)
+            key_cache = key_cache.view(kv_cache_torch_dtype)
+            value_cache = value_cache.view(kv_cache_torch_dtype)
+
+            rocm_aiter.reshape_and_cache_with_pertoken_quant(
+                key, value, key_cache, value_cache, k_scale, v_scale,
+                slot_mapping.flatten(), True)
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        seq_lens: torch.Tensor,
+        max_seq_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
+    ) -> torch.Tensor:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            return PagedAttention.forward_decode(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                kv_cache_dtype=kv_cache_dtype,
+                num_kv_heads=num_kv_heads,
+                scale=scale,
+                alibi_slopes=alibi_slopes,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                tp_rank=tp_rank,
+                blocksparse_local_blocks=blocksparse_local_blocks,
+                blocksparse_vert_stride=blocksparse_vert_stride,
+                blocksparse_block_size=blocksparse_block_size,
+                blocksparse_head_sliding_step=blocksparse_head_sliding_step)
+
+        if "fp8" in kv_cache_dtype:
+            key_cache = key_cache.view(torch.float8_e4m3fnuz)
+            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of "
+                 f"{block_size=} used in block_tables.")
+
+        output = torch.empty_like(query)
+        block_size = value_cache.shape[3]
+        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
+
+        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
+                              seq_lens, max_num_blocks_per_seq, k_scale,
+                              v_scale, output)
+        return output
diff --git a/vllm/envs.py b/vllm/envs.py
index d2e21a8dcfc4b..5922e90a91b1c 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -533,6 +534,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),

+    # Whether to use aiter paged attention.
+    # It is disabled by default.
+    "VLLM_ROCM_USE_AITER_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
+             ("true", "1")),
+
     # use aiter linear op if aiter ops are enabled
     # The following list of related ops
     # - scaled_mm (per-tensor / rowwise)
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f44507eb54ccb..091ec2a1f945a 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
-            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+                     and envs.VLLM_ROCM_USE_AITER))


 class RocmPlatform(Platform):
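
Usage sketch (illustrative, not part of the patch): the snippet below shows how the flag-gated dispatch introduced in vllm/attention/backends/rocm_flash_attn.py is expected to behave. It assumes a ROCm build of vLLM with the aiter package installed; the environment variables and helper names come from the patch, while the assertion and print are only a sanity check.

    import os

    # Both flags must be set before the attention backend is first used,
    # because is_rocm_aiter_paged_attn_enabled() and _get_paged_attn_module()
    # are wrapped in functools.cache and read the flags only once.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"

    from vllm.attention.backends.rocm_flash_attn import (
        _get_paged_attn_module, is_rocm_aiter_paged_attn_enabled)

    # With both flags set the helper returns the AITER wrapper; with either
    # flag unset it falls back to the original PagedAttention ops, and the
    # patched use_rocm_custom_paged_attention() check leaves the custom ROCm
    # kernel out of the decode path.
    assert is_rocm_aiter_paged_attn_enabled()
    print(type(_get_paged_attn_module()).__name__)  # AITERPagedAttention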