[V0 Deprecation] Remove placeholder attn (#25510)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell 2025-09-24 00:12:14 +02:00 committed by GitHub
parent 4f8c4b890a
commit 969b4da3a6
5 changed files with 10 additions and 354 deletions
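
In practice this is a signature cleanup: the is_attention_free positional argument is dropped from get_attn_backend() and _cached_get_attn_backend(), the PlaceholderAttentionBackend it used to select is deleted outright, and every call site loses the extra True/False argument. A minimal before/after sketch of a call site follows (illustration only; it assumes the usual import path from vllm.attention.selector, which is the module patched in the tests below, and borrows its argument values from those tests):

import torch

from vllm.attention.selector import get_attn_backend

# Before this commit, the fifth positional argument was is_attention_free:
#     backend = get_attn_backend(16, torch.float16, None, 16, False)

# After this commit the flag is gone; the remaining optional arguments
# (use_mla, has_sink) are unchanged:
backend = get_attn_backend(16, torch.float16, None, 16, use_mla=False)
print(backend.get_name())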

View File

@@ -85,8 +85,7 @@ def test_env(
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size,
False)
backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
elif device == "hip":
@@ -106,7 +105,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
@@ -117,7 +115,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
@@ -127,7 +124,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1"
assert backend.get_name() == expected
@@ -136,7 +132,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1"
assert backend.get_name() == expected
@@ -164,7 +159,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
@@ -179,7 +173,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
@@ -199,7 +192,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1"
assert backend.get_name() == expected
@@ -208,7 +200,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
@@ -218,7 +209,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_MLA_VLLM_V1"
assert backend.get_name() == expected
@@ -227,7 +217,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_VLLM_V1"
assert backend.get_name() == expected
@@ -236,7 +225,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_VLLM_V1"
assert backend.get_name() == expected
@@ -245,7 +233,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
@@ -264,13 +251,13 @@ def test_fp32_fallback(
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16, False)
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16, False)
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION"
@@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(torch.cuda,
"get_device_capability",
lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16, False)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Reset the monkeypatch for subsequent tests
monkeypatch.undo()
# Unsupported data type
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
backend = get_attn_backend(16, torch.float16, "fp8", 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = get_attn_backend(16, torch.float16, None, 8, False)
backend = get_attn_backend(16, torch.float16, None, 8)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed
import sys
original_module = sys.modules.get('vllm_flash_attn')
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
backend = get_attn_backend(16, torch.float16, None, 16, False)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Restore the original module if it existed
@@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = get_attn_backend(16, torch.float16, None, 16, True)
backend = get_attn_backend(17, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
@@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
get_attn_backend(32, torch.float16, None, 16, False)
get_attn_backend(32, torch.float16, None, 16)
assert "Invalid value 'INVALID'" in str(exc_info.value)

View File

@@ -1,314 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from itertools import accumulate
from typing import List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
"""Placeholder backend for when no attention is needed."""
@staticmethod
def get_name() -> str:
return "NO_ATTENTION"
@staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
return PlaceholderAttentionImpl
@staticmethod
def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
return PlaceholderAttentionMetadataBuilder
@staticmethod
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
return PlaceholderAttentionMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, 1, 1, 1, 1)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
return
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
return
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int]
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Placeholder.
block_tables: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)
self._cached_prefill_metadata = PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=0,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
self._cached_decode_metadata = PlaceholderAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
def __init__(self, input_builder):
self.input_builder = input_builder
self.runner = input_builder.runner
def prepare(self):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
"""
is_prompt = inter_data.is_prompt
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
# Some input builders such as ModelInputForCPUBuilder do not have the
# "inter_data_list" attribute.
# Let's check inter_data_list exists before we reference it.
if hasattr(self.input_builder, "inter_data_list"):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph:
num_decode_tokens = batch_size - self.num_prefill_tokens
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
# Placeholders
slot_mapping_tensor = torch.empty(0)
block_tables = torch.empty(0)
return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class PlaceholderAttentionImpl(AttentionImpl):
def __init__(self, *args, **kwargs) -> None:
return
def forward(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError

View File

@@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
@@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase):
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla,
has_sink=self.has_sink)
else:
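
Taken together, the two layer.py hunks above mean an Attention module now derives its cache parameters and picks a backend with no attention-free special case. A hypothetical helper sketching the post-change flow (resolve_backend_sketch is not a vLLM function; only the lines visible in the hunks are certain, and the parameter names are assumptions):

import torch

from vllm.attention.selector import get_attn_backend


def resolve_backend_sketch(cache_config, head_size: int, dtype: torch.dtype,
                           use_mla: bool = False, has_sink: bool = False):
    # Mirrors the post-change constructor logic shown above; cache_config is
    # duck-typed (anything exposing cache_dtype and block_size).
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16
    # is_attention_free no longer participates in backend selection:
    return get_attn_backend(head_size,
                            dtype,
                            kv_cache_dtype,
                            block_size,
                            use_mla=use_mla,
                            has_sink=has_sink)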

View File

@@ -142,7 +142,6 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
@@ -156,7 +155,6 @@ def get_attn_backend(
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
@@ -169,17 +167,10 @@ def _cached_get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
# Check whether a particular choice of backend was
# previously forced.
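
For reference, a sketch of the selector entry point after this hunk, reconstructed only from the visible context (the leading head-size parameter and its name are not shown in the hunk and are assumptions):

from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend


def get_attn_backend(
    head_size: int,  # assumed; this parameter is outside the hunk shown above
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
) -> type[AttentionBackend]:
    # Sketch only: with the attention-free early return removed, selection
    # always proceeds to the forced-backend check and the platform logic.
    ...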

View File

@@ -574,7 +574,6 @@ class NixlConnectorWorker:
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.use_mla)
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)