From 969b4da3a6ab737f72cb33db502b4c0bb70d4139 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Wed, 24 Sep 2025 00:12:14 +0200
Subject: [PATCH] [V0 Deprecation] Remove placeholder attn (#25510)

Signed-off-by: Thomas Parnell
---
 .../attention/test_attention_selector.py     |  37 +--
 vllm/attention/backends/placeholder_attn.py  | 314 ------------------
 vllm/attention/layer.py                      |   3 -
 vllm/attention/selector.py                   |   9 -
 .../kv_connector/v1/nixl_connector.py        |   1 -
 5 files changed, 10 insertions(+), 354 deletions(-)
 delete mode 100644 vllm/attention/backends/placeholder_attn.py

diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index a4e200775c09..730514eb5a56 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -85,8 +85,7 @@ def test_env(
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float16, None, block_size,
-                                       False)
+            backend = get_attn_backend(16, torch.float16, None, block_size)
         assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

     elif device == "hip":
@@ -106,7 +105,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 assert f"The selected backend, {name}" in str(
                     exc_info.value)
@@ -117,7 +115,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 assert f"The selected backend, {name}" in str(
                     exc_info.value)
@@ -127,7 +124,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = f"{name}_VLLM_V1"
                 assert backend.get_name() == expected
@@ -136,7 +132,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "TRITON_ATTN_VLLM_V1"
                 assert backend.get_name() == expected
@@ -164,7 +159,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "CUTLASS_MLA_VLLM_V1"
                 assert backend.get_name() == expected
@@ -179,7 +173,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "FLASHINFER_MLA"
                 assert backend.get_name() == expected
@@ -199,7 +192,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = f"{name}_VLLM_V1"
                 assert backend.get_name() == expected
@@ -208,7 +200,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "FLASH_ATTN_MLA"
                 assert backend.get_name() == expected
@@ -218,7 +209,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "TRITON_MLA_VLLM_V1"
                 assert backend.get_name() == expected
@@ -227,7 +217,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "FLASHINFER_VLLM_V1"
                 assert backend.get_name() == expected
@@ -236,7 +225,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 expected = "FLASH_ATTN_VLLM_V1"
                 assert backend.get_name() == expected
@@ -245,7 +233,6 @@ def test_env(
                                     torch.float16,
                                     None,
                                     block_size,
-                                    False,
                                     use_mla=use_mla)
                 assert backend.get_name() == "FLEX_ATTENTION", (
                     "Should fallback to FlexAttention if head size is "
@@ -264,13 +251,13 @@ def test_fp32_fallback(

     if device == "cpu":
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16)
         assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "FLEX_ATTENTION"

@@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr(torch.cuda, "get_device_capability",
                         lambda _=None: (7, 5))
-    backend = get_attn_backend(16, torch.float16, None, 16, False)
+    backend = get_attn_backend(16, torch.float16, None, 16)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Reset the monkeypatch for subsequent tests
     monkeypatch.undo()

     # Unsupported data type
-    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    backend = get_attn_backend(16, torch.float16, "fp8", 16)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    backend = get_attn_backend(16, torch.float16, None, 8)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     import sys
     original_module = sys.modules.get('vllm_flash_attn')
     monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
-    backend = get_attn_backend(16, torch.float16, None, 16, False)
+    backend = get_attn_backend(16, torch.float16, None, 16)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Restore the original module if it existed
@@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)

     # Unsupported head size
-    backend = get_attn_backend(17, torch.float16, None, 16, False)
-    assert backend.get_name() != STR_FLASH_ATTN_VAL
-
-    # Attention-free models should bypass env and use PlaceholderAttention
-    backend = get_attn_backend(16, torch.float16, None, 16, True)
+    backend = get_attn_backend(17, torch.float16, None, 16)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

@@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):

     # Should raise ValueError for invalid backend
     with pytest.raises(ValueError) as exc_info:
-        get_attn_backend(32, torch.float16, None, 16, False)
+        get_attn_backend(32, torch.float16, None, 16)
     assert "Invalid value 'INVALID'" in str(exc_info.value)
diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
deleted file mode 100644
index cddeb2cf39bf..000000000000
--- a/vllm/attention/backends/placeholder_attn.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from dataclasses import dataclass
-from itertools import accumulate
-from typing import List, Optional, Tuple, Type
-
-import torch
-
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder)
-from vllm.attention.backends.utils import CommonAttentionState
-from vllm.utils import async_tensor_h2d
-
-# Placeholder attention backend for models like Mamba and pooling models that
-# lack attention.
-
-
-class PlaceholderAttentionBackend(AttentionBackend):
-    """Placeholder backend for when no attention is needed."""
-
-    @staticmethod
-    def get_name() -> str:
-        return "NO_ATTENTION"
-
-    @staticmethod
-    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
-        return PlaceholderAttentionImpl
-
-    @staticmethod
-    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
-        return PlaceholderAttentionMetadataBuilder
-
-    @staticmethod
-    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
-        return PlaceholderAttentionMetadata
-
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-    ) -> Tuple[int, ...]:
-        return (1, 1, 1, 1, 1)
-
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        return
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        return
-
-
-@dataclass
-class PlaceholderAttentionMetadata(AttentionMetadata):
-    """Attention metadata for prefill and decode batched together."""
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
-    seq_lens_tensor: Optional[torch.Tensor]
-
-    # Maximum sequence length among prefill batch. 0 if there are decoding
-    # requests only.
-    max_prefill_seq_len: int
-    # Maximum sequence length among decode batch. 0 if there are prefill
-    # requests only.
-    max_decode_seq_len: int
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
-    # Maximum query length in the batch.
-    max_query_len: Optional[int]
-
-    # Max number of query tokens among request in the batch.
-    max_decode_query_len: Optional[int]
-
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor] = None
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor] = None
-
-    # Placeholder.
-    block_tables: Optional[torch.Tensor] = None
-
-    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
-    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
-
-    @property
-    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
-        if self.num_prefills == 0:
-            return None
-
-        if self._cached_prefill_metadata is not None:
-            return self._cached_prefill_metadata
-
-        # Compute some attn_metadata fields which default to None
-        query_start_loc = (None if self.query_start_loc is None else
-                           self.query_start_loc[:self.num_prefills + 1])
-        seq_lens = (None if self.seq_lens is None else
-                    self.seq_lens[:self.num_prefills])
-        seq_lens_tensor = (None if self.seq_lens_tensor is None else
-                           self.seq_lens_tensor[:self.num_prefills])
-        seq_start_loc = (None if self.seq_start_loc is None else
-                         self.seq_start_loc[:self.num_prefills + 1])
-        context_lens_tensor = (None if self.context_lens_tensor is None else
-                               self.context_lens_tensor[:self.num_prefills])
-
-        # Placeholders
-        slot_mapping = torch.empty(0)
-        block_tables = torch.empty(0)
-
-        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
-            num_prefills=self.num_prefills,
-            num_prefill_tokens=self.num_prefill_tokens,
-            num_decode_tokens=0,
-            slot_mapping=slot_mapping,
-            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            max_decode_query_len=0,
-            max_query_len=self.max_query_len,
-            max_prefill_seq_len=self.max_prefill_seq_len,
-            max_decode_seq_len=0,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=block_tables,
-            use_cuda_graph=False,
-        )
-        return self._cached_prefill_metadata
-
-    @property
-    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
-        if self.num_decode_tokens == 0:
-            return None
-
-        if self._cached_decode_metadata is not None:
-            return self._cached_decode_metadata
-        assert self.seq_lens_tensor is not None
-
-        # Placeholders
-        slot_mapping = torch.empty(0)
-        block_tables = torch.empty(0)
-        seq_lens_tensor = (None if self.seq_lens_tensor is None else
-                           self.seq_lens_tensor[self.num_prefills:])
-
-        self._cached_decode_metadata = PlaceholderAttentionMetadata(
-            num_prefills=0,
-            num_prefill_tokens=0,
-            num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=slot_mapping,
-            enable_kv_scales_calculation=True,
-            seq_lens=None,
-            seq_lens_tensor=seq_lens_tensor,
-            max_decode_query_len=self.max_decode_query_len,
-            max_query_len=None,
-            max_prefill_seq_len=0,
-            max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=(self.query_start_loc[self.num_prefills:] -
-                             self.query_start_loc[self.num_prefills])
-            if self.query_start_loc is not None else None,
-            seq_start_loc=self.seq_start_loc[self.num_prefills:]
-            if self.seq_start_loc is not None else None,
-            context_lens_tensor=None,
-            block_tables=block_tables,
-            use_cuda_graph=self.use_cuda_graph,
-        )
-        return self._cached_decode_metadata
-
-
-class PlaceholderAttentionMetadataBuilder(
-        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
-
-    def __init__(self, input_builder):
-
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-    def prepare(self):
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.curr_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
-        """Add a sequence group to the metadata. Specifically update/append
-        1. context length.
-        """
-        is_prompt = inter_data.is_prompt
-
-        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
-             curr_sliding_window_block) in zip(
-                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
-                 inter_data.orig_seq_lens, inter_data.seq_lens,
-                 inter_data.query_lens, inter_data.context_lens,
-                 inter_data.curr_sliding_window_blocks):
-            self.context_lens.append(context_len)
-
-            if is_prompt:
-                self.num_prefills += 1
-                self.num_prefill_tokens += token_len
-                self.prefill_seq_lens.append(seq_len)
-            else:
-                self.num_decode_tokens += query_len
-                self.curr_seq_lens.append(curr_seq_len)
-
-    def build(self, seq_lens: List[int], query_lens: List[int],
-              cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors.
-
-        Args:
-            seq_lens: The maybe padded sequence lengths of the input sequences.
-            query_lens: The query lengths of the input sequences.
-            cuda_graph_pad_size: The padding size for cuda graph.
-                                 -1 if cuda graph is not used.
-            batch_size: The maybe padded batch size.
-        """
-
-        # Some input builders such as ModelInputForCPUBuilder do not have the
-        # "inter_data_list" attribute.
-        # Let's check inter_data_list exists before we reference it.
-        if hasattr(self.input_builder, "inter_data_list"):
-            for inter_data in self.input_builder.inter_data_list:
-                self._add_seq_group(inter_data,
-                                    self.input_builder.chunked_prefill_enabled)
-
-        device = self.runner.device
-        use_captured_graph = cuda_graph_pad_size != -1
-
-        max_query_len = max(query_lens)
-        decode_query_lens = query_lens[self.num_prefills:]
-        if len(decode_query_lens) > 0:
-            max_decode_query_len = max(decode_query_lens)
-        else:
-            max_decode_query_len = 1
-        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
-        max_decode_seq_len = max(self.curr_seq_lens, default=0)
-        num_decode_tokens = self.num_decode_tokens
-        query_start_loc = list(accumulate(query_lens, initial=0))
-        seq_start_loc = list(accumulate(seq_lens, initial=0))
-
-        if use_captured_graph:
-            num_decode_tokens = batch_size - self.num_prefill_tokens
-
-        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
-
-        assert device is not None
-        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
-        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
-                                           self.runner.pin_memory)
-        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
-                                                  device,
-                                                  self.runner.pin_memory)
-        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
-                                                device, self.runner.pin_memory)
-
-        # Placeholders
-        slot_mapping_tensor = torch.empty(0)
-        block_tables = torch.empty(0)
-
-        return PlaceholderAttentionMetadata(
-            num_prefills=self.num_prefills,
-            slot_mapping=slot_mapping_tensor,
-            enable_kv_scales_calculation=True,
-            num_prefill_tokens=self.num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            max_query_len=max_query_len,
-            max_decode_query_len=max_decode_query_len,
-            max_prefill_seq_len=max_prefill_seq_len,
-            max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc_tensor,
-            seq_start_loc=seq_start_loc_tensor,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=block_tables,
-            use_cuda_graph=use_captured_graph,
-        )
-
-
-class PlaceholderAttentionImpl(AttentionImpl):
-
-    def __init__(self, *args, **kwargs) -> None:
-        return
-
-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        raise NotImplementedError
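For reference, the metadata builder deleted above derives query_start_loc and seq_start_loc as cumulative sums of the per-request lengths via itertools.accumulate. A minimal standalone sketch of that pattern, using hypothetical lengths (not part of the patch):

    from itertools import accumulate

    # Hypothetical per-request lengths, as the deleted build() received them.
    query_lens = [4, 6]
    seq_lens = [7, 9]

    # Cumulative start offsets, matching the deleted builder's use of accumulate.
    query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 4, 10]
    seq_start_loc = list(accumulate(seq_lens, initial=0))      # [0, 7, 16]
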
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 544a72052442..0ed20b3b7151 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase):
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
-            is_attention_free = cache_config.is_attention_free
             calculate_kv_scales = cache_config.calculate_kv_scales
         else:
             kv_cache_dtype = "auto"
             block_size = 16
-            is_attention_free = False
             calculate_kv_scales = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
@@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase):
                                         dtype,
                                         kv_cache_dtype,
                                         block_size,
-                                        is_attention_free,
                                         use_mla=use_mla,
                                         has_sink=self.has_sink)
         else:
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 3a235ba6e0b4..b651fc3eaee3 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -142,7 +142,6 @@ def get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool = False,
     use_mla: bool = False,
     has_sink: bool = False,
 ) -> type[AttentionBackend]:
@@ -156,7 +155,6 @@ def get_attn_backend(
         dtype=dtype,
         kv_cache_dtype=kv_cache_dtype,
         block_size=block_size,
-        is_attention_free=is_attention_free,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
         has_sink=has_sink,
@@ -169,17 +167,10 @@ def _cached_get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool,
     use_v1: bool = False,
     use_mla: bool = False,
     has_sink: bool = False,
 ) -> type[AttentionBackend]:
-    # If there are no attention layers (e.g. we are running Mamba),
-    # use the placeholder NO_ATTENTION
-    if is_attention_free:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend

     # Check whether a particular choice of backend was
     # previously forced.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 64feddb591c2..528d4022bd17 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -574,7 +574,6 @@ class NixlConnectorWorker:
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
-            self.model_config.is_attention_free,
             use_mla=self.use_mla)
         self.backend_name = backend.get_name()
         attn_backend = backend_name_to_enum(self.backend_name)
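
With this patch applied, callers select an attention backend without the removed is_attention_free flag, and attention-free models no longer short-circuit to the placeholder backend in the selector. A minimal usage sketch mirroring the updated test calls; the head size, dtype, kv-cache dtype and block size below are hypothetical example values, not taken from the patch:

    import torch

    from vllm.attention.selector import get_attn_backend

    # Positional arguments follow the updated signature in
    # vllm/attention/selector.py: head size, dtype, kv_cache_dtype, block_size;
    # use_mla and has_sink are passed by keyword and default to False.
    backend = get_attn_backend(64, torch.float16, None, 16, use_mla=False)
    print(backend.get_name())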