[FEAT][ROCm] Integrate Paged Attention Kernel from AITER (#15001)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

parent 8f7bace7c3
commit 0e237f0035
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="5a77249"
+ARG AITER_BRANCH="7e1ed08"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base
@@ -2,6 +2,7 @@
 """Attention layer ROCm GPUs."""
 import itertools
 from dataclasses import dataclass
+from functools import cache
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
@@ -26,6 +27,32 @@ logger = init_logger(__name__)
 _PARTITION_SIZE_ROCM = 256


+@cache
+def is_rocm_aiter_paged_attn_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+@cache
+def _get_paged_attn_module() -> PagedAttention:
+    """
+    Initializes the appropriate PagedAttention module from `attention/ops`,
+    which is used as helper function
+    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
+
+    The choice of attention module depends on whether
+    AITER paged attention is enabled:
+    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
+    - Otherwise, it defaults to using the original `PagedAttention`.
+    """
+    if is_rocm_aiter_paged_attn_enabled():
+        # Import AITERPagedAttention only when the flag is enabled
+        from vllm.attention.ops.rocm_aiter_paged_attn import (
+            AITERPagedAttention)
+        return AITERPagedAttention()
+    return PagedAttention()
+
+
 class ROCmFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
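As an aside (not part of the diff): both helpers above are memoized with functools.cache, so the env flags are read once and the chosen attention module sticks for the life of the process. A minimal standalone sketch of that pattern, using a hypothetical MY_FLAG in place of the real vLLM env vars:

import os
from functools import cache


@cache
def flag_enabled() -> bool:
    # Same truthy parsing style as vLLM's env lambdas ("true"/"1" enable it).
    return os.getenv("MY_FLAG", "False").lower() in ("true", "1")


os.environ["MY_FLAG"] = "1"
print(flag_enabled())  # True
os.environ["MY_FLAG"] = "0"
print(flag_enabled())  # still True: the first result was cached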
@@ -56,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        paged_attn = _get_paged_attn_module()
+        return paged_attn.get_kv_cache_shape(num_blocks, block_size,
+                                             num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@@ -65,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.copy_blocks(kv_caches, src_to_dists)


 @dataclass
@@ -496,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        self.paged_attn_module = _get_paged_attn_module()
+        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
+        )
+
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")

+        self.aiter_kv_scales_initialized = False
+
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
         tokens, n_kv_heads, head_dim = x.shape
@@ -624,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             assert value is None

+        paged_attn = self.paged_attn_module
+
+        # Reshaping kv tensors is required for AITER paged attention kernel
+        # because it works on a different tensor shape,
+        # when the size of one element is one byte (int8/fp8 dtypes).
+        # This reshaping is only required on the first forward call
+        # and the kv cache must not be empty.
+        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
+                and not self.aiter_kv_scales_initialized
+                and kv_cache.shape != torch.Size([0])):
+            num_blocks = kv_cache.shape[1]
+            block_size = kv_cache.shape[2] // (self.num_kv_heads *
+                                               self.head_size)
+            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            self.aiter_kv_scales_initialized = True
+            k_scale.fill_(layer._k_scale.item())
+            v_scale.fill_(layer._v_scale.item())
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+
         # Only update KV cache for decoder self-attention
         # and encoder-decoder cross-attention
         if self.attn_type not in [
                 AttentionType.ENCODER, AttentionType.ENCODER_ONLY
         ] and kv_cache.numel() > 0:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
+            key_cache, value_cache = paged_attn.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

             if key is not None and value is not None:
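A quick sanity check of the per-token scale shapes set up above, as a standalone sketch with made-up sizes. The [2, num_blocks, block_size * num_kv_heads * head_size] cache layout is an assumption inferred from the indexing in the diff:

import torch

# Toy sizes for illustration only.
num_kv_heads, head_size, num_blocks, block_size = 4, 128, 8, 16

# Assumed cache layout matching the indexing above:
# shape[1] -> num_blocks, shape[2] -> block_size * num_kv_heads * head_size.
kv_cache = torch.zeros(2, num_blocks, block_size * num_kv_heads * head_size,
                       dtype=torch.int8)

recovered_block_size = kv_cache.shape[2] // (num_kv_heads * head_size)
assert recovered_block_size == block_size

# One fp32 scale per KV head per cache slot, as allocated in the diff.
k_scale = torch.empty((num_kv_heads, num_blocks * block_size),
                      dtype=torch.float32)
print(k_scale.shape)  # torch.Size([4, 128])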
@@ -637,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # cache. If kv_cache is not provided, the new key and value
                 # tensors are not cached. This happens during the initial
                 # memory profiling run.
-                PagedAttention.write_to_paged_cache(
+                paged_attn.write_to_paged_cache(
                     key,
                     value,
                     key_cache,
@@ -768,23 +828,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
                 if self.attn_type != AttentionType.ENCODER_ONLY:
-                    output[:
-                           num_prefill_tokens] = PagedAttention.forward_prefix(
-                               query,
-                               key,
-                               value,
-                               self.kv_cache_dtype,
-                               key_cache,
-                               value_cache,
-                               prefill_meta.block_tables,
-                               prefill_meta.query_start_loc,
-                               prefill_meta.seq_lens_tensor,
-                               prefill_meta.max_query_len,
-                               self.alibi_slopes,
-                               self.sliding_window[0],
-                               layer._k_scale,
-                               layer._v_scale,
-                           )
+                    output[:num_prefill_tokens] = paged_attn.forward_prefix(
+                        query,
+                        key,
+                        value,
+                        self.kv_cache_dtype,
+                        key_cache,
+                        value_cache,
+                        prefill_meta.block_tables,
+                        prefill_meta.query_start_loc,
+                        prefill_meta.seq_lens_tensor,
+                        prefill_meta.max_query_len,
+                        self.alibi_slopes,
+                        self.sliding_window[0],
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
         # Skip decode phase for encoder-only models
         if (decode_meta := attn_metadata.decode_metadata) and (
                 self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         layer._v_scale,
                     )
                 else:
-                    output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    output[num_prefill_tokens:] = paged_attn.forward_decode(
                         decode_query,
                         key_cache,
                         value_cache,
vllm/attention/ops/rocm_aiter_paged_attn.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils import cdiv

FP8_DTYPE = current_platform.fp8_dtype()


class AITERPagedAttention(PagedAttention):

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache, slot_mapping,
                                                kv_cache_dtype, k_scale,
                                                v_scale)
        else:
            kv_cache_torch_dtype = (FP8_DTYPE
                                    if "fp8" in kv_cache_dtype else torch.int8)
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)

            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key, value, key_cache, value_cache, k_scale, v_scale,
                slot_mapping.flatten(), True)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step)

        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of"
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
                              seq_lens, max_num_blocks_per_seq, k_scale,
                              v_scale, output)
        return output
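For context (not part of the diff): the class above only routes one-byte cache dtypes ("int8", "fp8", "fp8_e4m3") to the AITER kernel; everything else falls straight through to the stock PagedAttention. A small sketch of that dispatch predicate and of the block-table geometry fed to pa_fwd_asm; the helper name is hypothetical, while cdiv is the real vllm.utils helper imported in the new file:

from vllm.utils import cdiv  # ceil-division helper, as imported above


def routed_to_aiter(kv_cache_dtype: str) -> bool:
    # Hypothetical mirror of the guard in AITERPagedAttention.
    return kv_cache_dtype in ("int8", "fp8", "fp8_e4m3")


assert routed_to_aiter("fp8") and not routed_to_aiter("auto")

# With an assumed block_size of 16 and a longest sequence of 1000 tokens,
# each block-table row handed to pa_fwd_asm needs ceil(1000 / 16) = 63 entries.
max_seq_len, block_size = 1000, 16
print(cdiv(max_seq_len, block_size))  # 63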
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -533,6 +534,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),

+    # Whether to use aiter paged attention.
+    # By default is disabled.
+    "VLLM_ROCM_USE_AITER_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
+             ("true", "1")),
+
     # use aiter linear op if aiter ops are enabled
     # The following list of related ops
     # - scaled_mm (per-tensor / rowwise)
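A hedged usage note: the lambdas above read os.environ when evaluated, so exporting both flags before vLLM starts is enough to opt in on a ROCm build with AITER available (for example, VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=1 in front of the usual launch command). The same parsing can be reproduced standalone; env_flag below is a hypothetical helper mirroring the lambdas:

import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"


def env_flag(name: str) -> bool:
    # Mirrors the truthy parsing in the env lambdas above.
    return os.getenv(name, "False").lower() in ("true", "1")


print(env_flag("VLLM_ROCM_USE_AITER"),
      env_flag("VLLM_ROCM_USE_AITER_PAGED_ATTN"))  # True True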
@@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
-            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
+            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+                     and envs.VLLM_ROCM_USE_AITER))


 class RocmPlatform(Platform):
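In plain terms, the extra clause above makes the ROCm custom paged-attention kernel step aside whenever both AITER flags are on. A standalone sketch of just that gating; the helper is hypothetical, the flag names are the real env vars:

def custom_paged_attn_allowed(use_custom: bool, use_aiter: bool,
                              use_aiter_paged_attn: bool) -> bool:
    # Mirrors the tail of use_rocm_custom_paged_attention(): the custom kernel
    # is only used when AITER paged attention is not taking over.
    return use_custom and not (use_aiter_paged_attn and use_aiter)


assert custom_paged_attn_allowed(True, False, False)
assert not custom_paged_attn_allowed(True, True, True)
# AITER paged attention needs BOTH flags; one alone leaves the custom kernel on.
assert custom_paged_attn_allowed(True, True, False)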