mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 02:44:26 +08:00
[FEAT][ROCm] Integrate Paged Attention Kernel from AITER (#15001)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
8f7bace7c3
commit
0e237f0035
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
|||||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||||
ARG FA_BRANCH="1a7f4dfa"
|
ARG FA_BRANCH="1a7f4dfa"
|
||||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||||
ARG AITER_BRANCH="5a77249"
|
ARG AITER_BRANCH="7e1ed08"
|
||||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||||
|
|
||||||
FROM ${BASE_IMAGE} AS base
|
FROM ${BASE_IMAGE} AS base
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
"""Attention layer ROCm GPUs."""
|
"""Attention layer ROCm GPUs."""
|
||||||
import itertools
|
import itertools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import cache
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -26,6 +27,32 @@ logger = init_logger(__name__)
|
|||||||
_PARTITION_SIZE_ROCM = 256
|
_PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def is_rocm_aiter_paged_attn_enabled() -> bool:
|
||||||
|
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
|
||||||
|
and envs.VLLM_ROCM_USE_AITER \
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _get_paged_attn_module() -> PagedAttention:
|
||||||
|
"""
|
||||||
|
Initializes the appropriate PagedAttention module from `attention/ops`,
|
||||||
|
which is used as helper function
|
||||||
|
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
|
||||||
|
|
||||||
|
The choice of attention module depends on whether
|
||||||
|
AITER paged attention is enabled:
|
||||||
|
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
|
||||||
|
- Otherwise, it defaults to using the original `PagedAttention`.
|
||||||
|
"""
|
||||||
|
if is_rocm_aiter_paged_attn_enabled():
|
||||||
|
# Import AITERPagedAttention only when the flag is enabled
|
||||||
|
from vllm.attention.ops.rocm_aiter_paged_attn import (
|
||||||
|
AITERPagedAttention)
|
||||||
|
return AITERPagedAttention()
|
||||||
|
return PagedAttention()
|
||||||
|
|
||||||
|
|
||||||
class ROCmFlashAttentionBackend(AttentionBackend):
|
class ROCmFlashAttentionBackend(AttentionBackend):
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@ -56,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
paged_attn = _get_paged_attn_module()
|
||||||
num_kv_heads, head_size)
|
return paged_attn.get_kv_cache_shape(num_blocks, block_size,
|
||||||
|
num_kv_heads, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@ -65,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
|||||||
dst_kv_cache: torch.Tensor,
|
dst_kv_cache: torch.Tensor,
|
||||||
src_to_dst: torch.Tensor,
|
src_to_dst: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
paged_attn = _get_paged_attn_module()
|
||||||
|
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_blocks(
|
def copy_blocks(
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
src_to_dists: torch.Tensor,
|
src_to_dists: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
paged_attn = _get_paged_attn_module()
|
||||||
|
paged_attn.copy_blocks(kv_caches, src_to_dists)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -496,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
self.paged_attn_module = _get_paged_attn_module()
|
||||||
|
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
|
||||||
|
)
|
||||||
|
|
||||||
if head_size not in supported_head_sizes:
|
if head_size not in supported_head_sizes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Head size {head_size} is not supported by PagedAttention. "
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
@ -546,6 +579,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
self.sdpa_attn_func = _sdpa_attention
|
self.sdpa_attn_func = _sdpa_attention
|
||||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||||
|
|
||||||
|
self.aiter_kv_scales_initialized = False
|
||||||
|
|
||||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||||
tokens, n_kv_heads, head_dim = x.shape
|
tokens, n_kv_heads, head_dim = x.shape
|
||||||
@ -624,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
else:
|
else:
|
||||||
assert value is None
|
assert value is None
|
||||||
|
|
||||||
|
paged_attn = self.paged_attn_module
|
||||||
|
|
||||||
|
# Reshaping kv tensors is required for AITER paged attention kernel
|
||||||
|
# because it works on a different tensor shape,
|
||||||
|
# when the size of one element is one byte (int8/fp8 dtypes).
|
||||||
|
# This reshaping is only required on the first forward call
|
||||||
|
# and the kv cache must not be empty.
|
||||||
|
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
|
||||||
|
and not self.aiter_kv_scales_initialized
|
||||||
|
and kv_cache.shape != torch.Size([0])):
|
||||||
|
num_blocks = kv_cache.shape[1]
|
||||||
|
block_size = kv_cache.shape[2] // (self.num_kv_heads *
|
||||||
|
self.head_size)
|
||||||
|
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=kv_cache.device)
|
||||||
|
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=kv_cache.device)
|
||||||
|
self.aiter_kv_scales_initialized = True
|
||||||
|
k_scale.fill_(layer._k_scale.item())
|
||||||
|
v_scale.fill_(layer._v_scale.item())
|
||||||
|
layer._k_scale = k_scale
|
||||||
|
layer._v_scale = v_scale
|
||||||
|
|
||||||
# Only update KV cache for decoder self-attention
|
# Only update KV cache for decoder self-attention
|
||||||
# and encoder-decoder cross-attention
|
# and encoder-decoder cross-attention
|
||||||
if self.attn_type not in [
|
if self.attn_type not in [
|
||||||
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
|
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
|
||||||
] and kv_cache.numel() > 0:
|
] and kv_cache.numel() > 0:
|
||||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
key_cache, value_cache = paged_attn.split_kv_cache(
|
||||||
kv_cache, self.num_kv_heads, self.head_size)
|
kv_cache, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
if key is not None and value is not None:
|
if key is not None and value is not None:
|
||||||
@ -637,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
# cache. If kv_cache is not provided, the new key and value
|
# cache. If kv_cache is not provided, the new key and value
|
||||||
# tensors are not cached. This happens during the initial
|
# tensors are not cached. This happens during the initial
|
||||||
# memory profiling run.
|
# memory profiling run.
|
||||||
PagedAttention.write_to_paged_cache(
|
paged_attn.write_to_paged_cache(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
@ -768,23 +828,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
# prefix-enabled attention -
|
# prefix-enabled attention -
|
||||||
# not applicable for encoder-only models
|
# not applicable for encoder-only models
|
||||||
if self.attn_type != AttentionType.ENCODER_ONLY:
|
if self.attn_type != AttentionType.ENCODER_ONLY:
|
||||||
output[:
|
output[:num_prefill_tokens] = paged_attn.forward_prefix(
|
||||||
num_prefill_tokens] = PagedAttention.forward_prefix(
|
query,
|
||||||
query,
|
key,
|
||||||
key,
|
value,
|
||||||
value,
|
self.kv_cache_dtype,
|
||||||
self.kv_cache_dtype,
|
key_cache,
|
||||||
key_cache,
|
value_cache,
|
||||||
value_cache,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.block_tables,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.query_start_loc,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.max_query_len,
|
||||||
prefill_meta.max_query_len,
|
self.alibi_slopes,
|
||||||
self.alibi_slopes,
|
self.sliding_window[0],
|
||||||
self.sliding_window[0],
|
layer._k_scale,
|
||||||
layer._k_scale,
|
layer._v_scale,
|
||||||
layer._v_scale,
|
)
|
||||||
)
|
|
||||||
# Skip decode phase for encoder-only models
|
# Skip decode phase for encoder-only models
|
||||||
if (decode_meta := attn_metadata.decode_metadata) and (
|
if (decode_meta := attn_metadata.decode_metadata) and (
|
||||||
self.attn_type != AttentionType.ENCODER_ONLY):
|
self.attn_type != AttentionType.ENCODER_ONLY):
|
||||||
@ -843,7 +902,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
layer._v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
output[num_prefill_tokens:] = paged_attn.forward_decode(
|
||||||
decode_query,
|
decode_query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
|
|||||||
101
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
101
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
class AITERPagedAttention(PagedAttention):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_to_paged_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||||
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
|
value_cache, slot_mapping,
|
||||||
|
kv_cache_dtype, k_scale,
|
||||||
|
v_scale)
|
||||||
|
else:
|
||||||
|
kv_cache_torch_dtype = (FP8_DTYPE
|
||||||
|
if "fp8" in kv_cache_dtype else torch.int8)
|
||||||
|
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||||
|
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||||
|
|
||||||
|
rocm_aiter.reshape_and_cache_with_pertoken_quant(
|
||||||
|
key, value, key_cache, value_cache, k_scale, v_scale,
|
||||||
|
slot_mapping.flatten(), True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_decode(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
max_seq_len: int,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
num_kv_heads: int,
|
||||||
|
scale: float,
|
||||||
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
v_scale: torch.Tensor,
|
||||||
|
tp_rank: int = 0,
|
||||||
|
blocksparse_local_blocks: int = 0,
|
||||||
|
blocksparse_vert_stride: int = 0,
|
||||||
|
blocksparse_block_size: int = 64,
|
||||||
|
blocksparse_head_sliding_step: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||||
|
return PagedAttention.forward_decode(
|
||||||
|
query=query,
|
||||||
|
key_cache=key_cache,
|
||||||
|
value_cache=value_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
scale=scale,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||||
|
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||||
|
blocksparse_block_size=blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
|
||||||
|
|
||||||
|
if "fp8" in kv_cache_dtype:
|
||||||
|
key_cache = key_cache.view(torch.float8_e4m3fnuz)
|
||||||
|
value_cache = value_cache.view(torch.float8_e4m3fnuz)
|
||||||
|
|
||||||
|
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||||
|
# use blocksparse paged attention
|
||||||
|
block_size = value_cache.size(-1)
|
||||||
|
assert (blocksparse_block_size > 0 and
|
||||||
|
blocksparse_block_size % block_size == 0), \
|
||||||
|
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||||
|
f"{block_size=} used in block_tables.")
|
||||||
|
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
|
||||||
|
|
||||||
|
rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
|
||||||
|
seq_lens, max_num_blocks_per_seq, k_scale,
|
||||||
|
v_scale, output)
|
||||||
|
return output
|
||||||
@ -75,6 +75,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_DISABLED_KERNELS: list[str] = []
|
VLLM_DISABLED_KERNELS: list[str] = []
|
||||||
VLLM_USE_V1: bool = True
|
VLLM_USE_V1: bool = True
|
||||||
VLLM_ROCM_USE_AITER: bool = False
|
VLLM_ROCM_USE_AITER: bool = False
|
||||||
|
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
|
||||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||||
@ -533,6 +534,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
|
||||||
("true", "1")),
|
("true", "1")),
|
||||||
|
|
||||||
|
# Whether to use aiter paged attention.
|
||||||
|
# By default is disabled.
|
||||||
|
"VLLM_ROCM_USE_AITER_PAGED_ATTN":
|
||||||
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
|
||||||
|
("true", "1")),
|
||||||
|
|
||||||
# use aiter linear op if aiter ops are enabled
|
# use aiter linear op if aiter ops are enabled
|
||||||
# The following list of related ops
|
# The following list of related ops
|
||||||
# - scaled_mm (per-tensor / rowwise)
|
# - scaled_mm (per-tensor / rowwise)
|
||||||
|
|||||||
@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
|||||||
and (head_size == 64 or head_size == 128)
|
and (head_size == 64 or head_size == 128)
|
||||||
and (block_size == 16 or block_size == 32)
|
and (block_size == 16 or block_size == 32)
|
||||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
|
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
|
||||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||||
|
and envs.VLLM_ROCM_USE_AITER))
|
||||||
|
|
||||||
|
|
||||||
class RocmPlatform(Platform):
|
class RocmPlatform(Platform):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user