[V0 deprecation] Clean up legacy paged attention helper functions (#28043)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Isotr0py on 2025-11-29 00:44:33 +08:00, committed by GitHub
parent fae6943068
commit 6f9d81d03b
2 changed files with 0 additions and 334 deletions

vllm/attention/ops/paged_attn.py

@@ -1,58 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

if HAS_TRITON:
    from vllm.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""

    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: torch.Tensor | None
    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: torch.Tensor | None


class PagedAttention:
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
@@ -89,174 +49,3 @@ class PagedAttention:
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: torch.Tensor | None,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (
                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
            ), (
                f"{blocksparse_block_size=} needs to be a multiple of "
                f"{block_size=} used in block_tables."
            )

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = max_seq_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512
        )

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: torch.Tensor | None,
        sliding_window: int | None,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
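
For reference, the decode-path selection that the deleted helper above implemented can be restated in isolation. The sketch below is illustrative only; `kv_cache_shape` and `use_paged_attention_v1` are hypothetical names, not vLLM APIs, and merely restate the flat KV-cache layout of `get_kv_cache_shape` and the V1/V2 partition heuristic from the removed `forward_decode`.

# Illustrative sketch only -- not part of the vLLM code base.
_PARTITION_SIZE = 512  # should match PARTITION_SIZE in `paged_attention_v2_launcher`


def kv_cache_shape(
    num_blocks: int, block_size: int, num_kv_heads: int, head_size: int
) -> tuple[int, ...]:
    # Key and value planes share one tensor, hence the leading 2.
    return (2, num_blocks, block_size * num_kv_heads * head_size)


def use_paged_attention_v1(max_seq_len: int, num_seqs: int, num_heads: int) -> bool:
    # V1 skips the cross-partition reduction, so it is preferred when there is
    # only one partition or already enough sequence/head-level parallelism;
    # beyond 8192 tokens, V2 is used to avoid shared-memory shortage.
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)


if __name__ == "__main__":
    print(kv_cache_shape(num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128))  # (2, 1024, 16384)
    print(use_paged_attention_v1(max_seq_len=512, num_seqs=4, num_heads=32))  # True: a single partition
    print(use_paged_attention_v1(max_seq_len=4096, num_seqs=2, num_heads=8))  # False: 8 partitions, little parallelism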


@@ -1,123 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv

FP8_DTYPE = current_platform.fp8_dtype()


class AITERPagedAttention(PagedAttention):
    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
        else:
            kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)
            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key,
                value,
                key_cache,
                value_cache,
                k_scale,
                v_scale,
                slot_mapping.flatten(),
                True,
            )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: torch.Tensor | None,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step,
            )

        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (
                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
            ), (
                f"{blocksparse_block_size=} needs to be a multiple of "
                f"{block_size=} used in block_tables."
            )

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        rocm_aiter.pa_fwd_asm(
            query,
            key_cache,
            value_cache,
            block_tables,
            seq_lens,
            max_num_blocks_per_seq,
            k_scale,
            v_scale,
            output,
        )
        return output
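
The deleted ROCm wrapper above reduced to a simple dispatch: quantized KV-cache dtypes ("int8", "fp8", "fp8_e4m3") went to the rocm_aiter per-token-quant kernels, everything else fell back to the base `PagedAttention` helpers, and the ASM decode kernel received a block-table width of `cdiv(max_seq_len, block_size)`. A minimal sketch of those two pieces, with hypothetical helper names (`uses_aiter_quant_path`, `max_blocks_per_seq`) that are not part of vLLM:

# Illustrative sketch only -- mirrors the dispatch of the deleted AITERPagedAttention.
_AITER_QUANT_DTYPES = ("int8", "fp8", "fp8_e4m3")


def uses_aiter_quant_path(kv_cache_dtype: str) -> bool:
    # Only quantized caches took the rocm_aiter kernels; other dtypes
    # (e.g. "auto") fell back to the base PagedAttention implementation.
    return kv_cache_dtype in _AITER_QUANT_DTYPES


def max_blocks_per_seq(max_seq_len: int, block_size: int) -> int:
    # Ceiling division, like vllm.utils.math_utils.cdiv: the widest block
    # table the decode kernel may need to index for one sequence.
    return -(-max_seq_len // block_size)


if __name__ == "__main__":
    print(uses_aiter_quant_path("fp8"))   # True  -> rocm_aiter path
    print(uses_aiter_quant_path("auto"))  # False -> PagedAttention fallback
    print(max_blocks_per_seq(4096, 16))   # 256
    print(max_blocks_per_seq(4100, 16))   # 257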