From 6f9d81d03b3b60f8a8b51d624c86f99bf26e96a4 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 29 Nov 2025 00:44:33 +0800
Subject: [PATCH] [V0 deprecation] Clean up legacy paged attention helper functions (#28043)

Signed-off-by: Isotr0py
---
 vllm/attention/ops/paged_attn.py            | 211 --------------------
 vllm/attention/ops/rocm_aiter_paged_attn.py | 123 ------------
 2 files changed, 334 deletions(-)
 delete mode 100644 vllm/attention/ops/rocm_aiter_paged_attn.py

diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 8e010ffba32ec..4aa4bcf5bbd36 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -1,58 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
 
 import torch
 
 from vllm.platforms import current_platform
-from vllm.triton_utils import HAS_TRITON
 
 if current_platform.is_cuda_alike():
     from vllm import _custom_ops as ops
 elif current_platform.is_xpu():
     from vllm._ipex_ops import ipex_ops as ops
 
-if HAS_TRITON:
-    from vllm.attention.ops.prefix_prefill import context_attention_fwd
-
-# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
-_PARTITION_SIZE = 512
-
-
-@dataclass
-class PagedAttentionMetadata:
-    """Metadata for PagedAttention."""
-
-    # (batch_size,). The length of sequences (entire tokens seen so far) per
-    # sequence.
-    seq_lens_tensor: torch.Tensor | None
-    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: torch.Tensor | None
-
 
 class PagedAttention:
-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
-        return [32, 64, 80, 96, 112, 120, 128, 192, 256]
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-        cache_dtype_str: str = "auto",
-    ) -> tuple[int, ...]:
-        return (2, num_blocks, block_size * num_kv_heads * head_size)
-
     @staticmethod
     def split_kv_cache(
         kv_cache: torch.Tensor,
@@ -89,174 +49,3 @@ class PagedAttention:
             k_scale,
             v_scale,
         )
-
-    @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        max_seq_len: int,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: torch.Tensor | None,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-        tp_rank: int = 0,
-        blocksparse_local_blocks: int = 0,
-        blocksparse_vert_stride: int = 0,
-        blocksparse_block_size: int = 64,
-        blocksparse_head_sliding_step: int = 0,
-    ) -> torch.Tensor:
-        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
-            # use blocksparse paged attention
-            block_size = value_cache.size(-1)
-            assert (
-                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
-            ), (
-                f"{blocksparse_block_size=} needs to be a multiple of"
-                f"{block_size=} used in block_tables."
-            )
-
-        output = torch.empty_like(query)
-        block_size = value_cache.shape[3]
-        num_seqs, num_heads, head_size = query.shape
-        max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-        # NOTE(woosuk): We use a simple heuristic to decide whether to use
-        # PagedAttention V1 or V2. If the number of partitions is 1, we use
-        # V1 to avoid the overhead of reduction. Also, if the number of
-        # sequences or heads is large, we use V1 since there is enough work
-        # to parallelize.
-        # TODO(woosuk): Tune this heuristic.
-        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-        use_v1 = max_seq_len <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512
-        )
-
-        if use_v1:
-            # Run PagedAttention V1.
-            ops.paged_attention_v1(
-                output,
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                scale,
-                block_tables,
-                seq_lens,
-                block_size,
-                max_seq_len,
-                alibi_slopes,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-                tp_rank,
-                blocksparse_local_blocks,
-                blocksparse_vert_stride,
-                blocksparse_block_size,
-                blocksparse_head_sliding_step,
-            )
-        else:
-            # Run PagedAttention V2.
-            assert _PARTITION_SIZE % block_size == 0
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=output.dtype,
-                device=output.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=torch.float32,
-                device=output.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
-            ops.paged_attention_v2(
-                output,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                scale,
-                block_tables,
-                seq_lens,
-                block_size,
-                max_seq_len,
-                alibi_slopes,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-                tp_rank,
-                blocksparse_local_blocks,
-                blocksparse_vert_stride,
-                blocksparse_block_size,
-                blocksparse_head_sliding_step,
-            )
-        return output
-
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        kv_cache_dtype: str,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        query_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: torch.Tensor | None,
-        sliding_window: int | None,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-    ) -> torch.Tensor:
-        output = torch.empty_like(query)
-        max_seq_len = None
-        context_attention_fwd(
-            query,
-            key,
-            value,
-            output,
-            kv_cache_dtype,
-            key_cache,
-            value_cache,
-            block_tables,
-            # query_start_loc is (batch_size + 1,)
-            query_start_loc,
-            seq_lens_tensor,
-            max_seq_len,
-            max_query_len,
-            k_scale,
-            v_scale,
-            alibi_slopes,
-            sliding_window,
-        )
-        return output
-
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        src_key_cache = src_kv_cache[0]
-        dst_key_cache = dst_kv_cache[0]
-        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-
-        src_value_cache = src_kv_cache[1]
-        dst_value_cache = dst_kv_cache[1]
-        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: list[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        key_caches = [kv_cache[0] for kv_cache in kv_caches]
-        value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        ops.copy_blocks(key_caches, value_caches, src_to_dists)
diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py
deleted file mode 100644
index bcd1e2cd56441..0000000000000
--- a/vllm/attention/ops/rocm_aiter_paged_attn.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import aiter as rocm_aiter
-import torch
-
-from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.platforms import current_platform
-from vllm.utils.math_utils import cdiv
-
-FP8_DTYPE = current_platform.fp8_dtype()
-
-
-class AITERPagedAttention(PagedAttention):
-    @staticmethod
-    def write_to_paged_cache(
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        slot_mapping: torch.Tensor,
-        kv_cache_dtype: str,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-    ) -> None:
-        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
-            PagedAttention.write_to_paged_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
-        else:
-            kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8
-            key_cache = key_cache.view(kv_cache_torch_dtype)
-            value_cache = value_cache.view(kv_cache_torch_dtype)
-
-            rocm_aiter.reshape_and_cache_with_pertoken_quant(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                k_scale,
-                v_scale,
-                slot_mapping.flatten(),
-                True,
-            )
-
-    @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        max_seq_len: int,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: torch.Tensor | None,
-        k_scale: torch.Tensor,
-        v_scale: torch.Tensor,
-        tp_rank: int = 0,
-        blocksparse_local_blocks: int = 0,
-        blocksparse_vert_stride: int = 0,
-        blocksparse_block_size: int = 64,
-        blocksparse_head_sliding_step: int = 0,
-    ) -> torch.Tensor:
-        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
-            return PagedAttention.forward_decode(
-                query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                max_seq_len=max_seq_len,
-                kv_cache_dtype=kv_cache_dtype,
-                num_kv_heads=num_kv_heads,
-                scale=scale,
-                alibi_slopes=alibi_slopes,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                tp_rank=tp_rank,
-                blocksparse_local_blocks=blocksparse_local_blocks,
-                blocksparse_vert_stride=blocksparse_vert_stride,
-                blocksparse_block_size=blocksparse_block_size,
-                blocksparse_head_sliding_step=blocksparse_head_sliding_step,
-            )
-
-        if "fp8" in kv_cache_dtype:
-            key_cache = key_cache.view(current_platform.fp8_dtype())
-            value_cache = value_cache.view(current_platform.fp8_dtype())
-
-        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
-            # use blocksparse paged attention
-            block_size = value_cache.size(-1)
-            assert (
-                blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
-            ), (
-                f"{blocksparse_block_size=} needs to be a multiple of"
-                f"{block_size=} used in block_tables."
-            )
-
-        output = torch.empty_like(query)
-        block_size = value_cache.shape[3]
-        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
-
-        rocm_aiter.pa_fwd_asm(
-            query,
-            key_cache,
-            value_cache,
-            block_tables,
-            seq_lens,
-            max_num_blocks_per_seq,
-            k_scale,
-            v_scale,
-            output,
-        )
-        return output