Refactor code, move static sink logic to builder

Signed-off-by: yuantao <2422264527@qq.com>
yuantao 2025-12-22 22:48:14 +08:00
parent 4ad3f75875
commit cf58a62099
7 changed files with 48 additions and 209 deletions

View File

@ -17,6 +17,8 @@ from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
@ -48,9 +50,23 @@ def create_static_sink_attention_backend(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
model_config = vllm_config.model_config
scheduler_config = vllm_config.scheduler_config
self.sink_len = sink_len
self.block_size = vllm_config.cache_config.block_size
self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
self.sink_block_table = torch.arange(
self.max_num_blocks = cdiv(
model_config.max_model_len, vllm_config.cache_config.block_size
)
self.block_table_with_sink = torch.zeros(
(
scheduler_config.max_num_seqs,
self.max_num_blocks + self.num_sink_blocks,
),
device=device,
dtype=torch.int32,
)
self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
1,
self.num_sink_blocks + 1,
device=device,
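The builder now owns the static-sink bookkeeping: it pre-allocates one block table wide enough for max_model_len plus the sink prefix, and pins the first num_sink_blocks columns to the reserved sink blocks. A sketch of the resulting layout with assumed sizes (block_size=16, sink_len=32, max_num_seqs=2, max_num_blocks=4); in the diff these all come from vllm_config:

import torch

block_size = 16           # vllm_config.cache_config.block_size (assumed value)
sink_len = 32             # two full blocks of sink tokens (assumed value)
num_sink_blocks = sink_len // block_size        # -> 2
max_num_blocks = 4        # cdiv(max_model_len, block_size) (assumed value)
max_num_seqs = 2

# Columns [0, num_sink_blocks) always point at the reserved sink blocks
# (block IDs 1..num_sink_blocks); the rest is filled per request at build time.
block_table_with_sink = torch.zeros(
    (max_num_seqs, max_num_blocks + num_sink_blocks), dtype=torch.int32
)
block_table_with_sink[:, :num_sink_blocks] = torch.arange(1, num_sink_blocks + 1)
print(block_table_with_sink)
# tensor([[1, 2, 0, 0, 0, 0],
#         [1, 2, 0, 0, 0, 0]], dtype=torch.int32)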
@ -72,6 +88,14 @@ def create_static_sink_attention_backend(
common_attn_metadata.max_seq_len = (
common_attn_metadata.max_seq_len + self.sink_len
)
max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
num_reqs = common_attn_metadata.num_reqs
self.block_table_with_sink[
:num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
common_attn_metadata.block_table_tensor = self.block_table_with_sink[
:num_reqs
]
return super().build(common_prefix_len, common_attn_metadata, fast_build)
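At build() time each request's own block table is spliced in after the sink columns, and max_seq_len grows by sink_len, so the underlying backend sees the sink KV as a prefix of every sequence. Continuing the sketch above with an assumed request holding blocks 7, 8, 9:

# Splice a request's block table in after the sink columns, as build() does
# (num_reqs and the request's block IDs are assumed values).
num_reqs = 1
req_blocks = torch.tensor([[7, 8, 9]], dtype=torch.int32)
block_table_with_sink[:num_reqs, num_sink_blocks:num_sink_blocks + 3] = req_blocks
print(block_table_with_sink[:num_reqs])
# tensor([[1, 2, 7, 8, 9, 0]], dtype=torch.int32)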
@ -84,7 +108,8 @@ def create_static_sink_attention_backend(
return attn_backend
class StaticSinkAttention(Attention):
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
"""
Attention with static sink tokens
"""
@ -118,7 +143,8 @@ class StaticSinkAttention(Attention):
underlying_attn_backend,
sink_len=sink_len,
)
super().__init__(
Attention.__init__(
self=self,
num_heads=num_heads,
head_size=head_size,
scale=scale,
@ -126,6 +152,7 @@ class StaticSinkAttention(Attention):
attn_backend=attn_backend,
**kwargs,
)
CustomOp.__init__(self)
self.sink_len = sink_len
self.block_size = block_size
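Because StaticSinkAttention now has two bases, the commit initializes each one explicitly instead of calling super().__init__(), since Attention and CustomOp do not cooperatively chain their __init__ kwargs. A toy illustration of the pattern (the classes here are stand-ins, not vLLM types):

class A:
    def __init__(self, x):
        self.x = x

class B:
    def __init__(self):
        self.ready = True

class C(A, B):
    def __init__(self, x):
        # Call each base explicitly; super().__init__(x) would only reach A
        # and relies on every base forwarding arguments cooperatively.
        A.__init__(self, x)
        B.__init__(self)

c = C(42)
print(c.x, c.ready)  # 42 True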
@ -137,7 +164,7 @@ class StaticSinkAttention(Attention):
self.sink_key = sink_key
self.sink_value = sink_value
def forward(
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
@ -154,6 +181,18 @@ class StaticSinkAttention(Attention):
return super().forward(query, key, value, output_shape)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
return self.forward_native(query, key, value, output_shape)
def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)
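forward() defers to _forward_method, which CustomOp resolves once to the platform-specific entry point (forward_cuda on CUDA, otherwise forward_native); here forward_cuda simply reuses the native path, so the registration buys dispatch without a separate kernel. A toy sketch of that dispatch, illustrative only, not vLLM's actual CustomOp internals:

import torch

class ToyCustomOp:
    def __init__(self):
        # Resolve the platform-specific implementation once, up front.
        impl = "forward_cuda" if torch.cuda.is_available() else "forward_native"
        self._forward_method = getattr(self, impl)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x):
        return x + 1

    def forward_cuda(self, x):
        # Like StaticSinkAttention.forward_cuda, reuse the native path.
        return self.forward_native(x)

print(ToyCustomOp().forward(torch.tensor([1.0])))  # tensor([2.])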
def populate_sink_kv(self, self_kv_cache):
sink_kv_slot_mapping = torch.arange(
self.block_size,
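The hunk cuts off here, but the slot arithmetic it begins is recoverable from the layout above: with the sink tokens pinned in block IDs 1..num_sink_blocks (block 0 being the null block), their flat KV slots form one contiguous range starting at block_size. A sketch with assumed sizes:

import torch

block_size, num_sink_blocks = 16, 2
sink_kv_slot_mapping = torch.arange(
    block_size, block_size * (num_sink_blocks + 1)
)  # slots 16..47: block 1 holds slots 16-31, block 2 holds slots 32-47
print(sink_kv_slot_mapping.shape)  # torch.Size([32])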

View File

@ -801,147 +801,6 @@ class SinkFullAttentionManager(FullAttentionManager):
num_sink_block = sink_len // self.block_size
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
Returns:
The number of blocks.
"""
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (
num_required_blocks
- len(new_computed_blocks)
- len(self.req_to_blocks[request_id])
)
# For running requests, req_to_blocks already includes the sink blocks,
# which never hold request tokens, so add them back to num_new_blocks
if len(self.req_to_blocks[request_id]) > 0:
num_new_blocks = num_new_blocks + len(self.sink_blocks)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
)
return num_new_blocks + num_evictable_computed_blocks
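The subtraction above can undercount for running requests: req_to_blocks includes the shared sink blocks, which never hold request tokens, so their count is added back. A worked example with assumed numbers (prefix-cache hits and evictable blocks left out for simplicity):

block_size = 16
num_tokens = 100                          # tokens needing a slot
num_sink_blocks = 2
num_held = num_sink_blocks + 3            # len(self.req_to_blocks[request_id])

num_required_blocks = -(-num_tokens // block_size)   # cdiv(100, 16) -> 7
num_new_blocks = num_required_blocks - num_held      # 7 - 5 = 2
num_new_blocks += num_sink_blocks                    # sinks hold no tokens -> 4
print(num_new_blocks)  # 4: the request really needs 4 more data blocks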
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
if request_id not in self.num_cached_block:
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
# Append both the sink blocks and the blocks that hit the prefix cache
req_blocks.extend(self.sink_blocks)
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0
def allocate_new_blocks(
self, request_id: str, num_tokens: int
) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
req_blocks = self.req_to_blocks[request_id]
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = num_required_blocks - len(req_blocks)
# For existing requests, req_blocks already includes the sink blocks,
# so compensate for them in num_new_blocks
if len(req_blocks) > 0:
num_new_blocks = num_new_blocks + len(self.sink_blocks)
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
# For new requests, allocate sink blocks
if len(req_blocks) == 0:
req_blocks.extend(self.sink_blocks + new_blocks)
else:
req_blocks.extend(new_blocks)
return new_blocks
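allocate_new_blocks mirrors the same correction, and additionally prepends the shared sink blocks the first time a request receives any blocks, so the sinks always sit at the head of its block table. A minimal sketch with assumed block IDs:

sink_blocks = [1, 2]
req_blocks: list[int] = []          # new request: nothing allocated yet
new_blocks = [7, 8, 9]              # from block_pool.get_new_blocks(...)

if len(req_blocks) == 0:
    req_blocks.extend(sink_blocks + new_blocks)   # sinks lead the table
else:
    req_blocks.extend(new_blocks)
print(req_blocks)  # [1, 2, 7, 8, 9]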
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
num_full_blocks = num_tokens // self.block_size
if num_cached_blocks >= num_full_blocks:
return
self.block_pool.cache_full_blocks(
request=request,
# Do not cache sink blocks
blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :],
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
# Do not free sink blocks
if len(req_blocks) > 0:
req_blocks = req_blocks[len(self.sink_blocks) :]
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(req_blocks)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
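free() is the inverse: the shared sink blocks are sliced off so they survive for the next request, and only the request's own blocks return to the pool, tail first, so that prefix blocks land deeper in the free queue and stay cached longer. A sketch with assumed IDs:

sink_blocks = [1, 2]
req_blocks = [1, 2, 7, 8, 9]
to_free = req_blocks[len(sink_blocks):]   # keep the shared sinks alive
print(list(reversed(to_free)))  # [9, 8, 7] -- tail blocks are freed first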
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,

View File

@ -23,7 +23,6 @@ class BlockTable:
device: torch.device,
kernel_block_size: int,
cp_kv_cache_interleave_size: int,
sink_len: int = 0,
):
"""
Args:
@ -64,8 +63,6 @@ class BlockTable:
self.use_hybrid_blocks = True
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.sink_block_len = sink_len // self.block_size
self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len
self.block_table = self._make_buffer(
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
@ -154,7 +151,7 @@ class BlockTable:
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
) + self.sink_block_len
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
@ -180,10 +177,9 @@ class BlockTable:
mask, slot_mapping, -1
)
else:
# When self.sink_block_len > 0, we need to shift block table indices
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
) + self.sink_block_len
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
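With the sink shift gone, what remains is the plain row-major lookup into the flattened block table. A worked example of the slot-mapping arithmetic with assumed shapes (the final block_numbers * block_size + block_offsets step is cut off by the hunk but follows the same pattern as the masked branch above):

import numpy as np

block_size, max_num_blocks_per_req = 16, 4
block_table = np.array([[7, 8, 9, 0],
                        [5, 6, 0, 0]], dtype=np.int32)

req_indices = np.array([0, 0, 1])   # which request each token belongs to
positions = np.array([3, 17, 20])   # token positions within each request

idx = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.ravel()[idx]             # [7, 8, 6]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # [115 129 100]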
@ -267,7 +263,6 @@ class MultiGroupBlockTable:
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
sink_len: int = 0,
) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,
@ -305,7 +300,6 @@ class MultiGroupBlockTable:
device,
kernel_block_size,
cp_kv_cache_interleave_size,
sink_len=sink_len,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]

View File

@ -101,28 +101,16 @@ def _reshape_kv_cache(
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
if (
hasattr(kv_cache_spec, "head_size_v")
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
**kwargs,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
**stride_kwargs
)
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
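With the head_size_v special case removed, every call site now probes get_kv_cache_stride_order() with no arguments and falls back to the identity permutation when the backend does not define one. A self-contained sketch of that handshake (the backend stub and shape below are stand-ins, not a real backend):

import torch

kv_cache_shape = (2, 128, 16, 8, 64)  # e.g. (2, num_blocks, block, heads, head)

def get_kv_cache_stride_order():
    raise NotImplementedError          # a backend may not define it

try:
    stride_order = get_kv_cache_stride_order()
    assert len(stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
    stride_order = tuple(range(len(kv_cache_shape)))  # identity fallback

# Allocate in the backend's physical order, then permute back to the
# logical shape so attention kernels see the layout they expect.
alloc_shape = tuple(kv_cache_shape[i] for i in stride_order)
kv_cache = torch.empty(alloc_shape).permute(
    *[stride_order.index(i) for i in range(len(stride_order))]
)
print(kv_cache.shape)  # torch.Size([2, 128, 16, 8, 64])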

View File

@ -96,7 +96,6 @@ class InputBatch:
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
sink_len: int = 0,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@ -151,7 +150,6 @@ class InputBatch:
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
sink_len=sink_len,
)
# Sampling-related.

View File

@ -332,10 +332,6 @@ class GPUModelRunner(
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.sink_len = getattr(
self.vllm_config.model_config.hf_config, "param_sink_number", 0
)
assert self.sink_len % self.cache_config.block_size == 0
# Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = model_config.uses_alibi
@ -459,7 +455,6 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
sink_len=self.sink_len,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
@ -5079,7 +5074,7 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=self.num_spec_tokens,
sink_len=self.sink_len,
# sink_len=self.sink_len,
)
def _allocate_kv_cache_tensors(
@ -5206,28 +5201,16 @@ class GPUModelRunner(
)
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
if (
hasattr(kv_cache_spec, "head_size_v")
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
**kwargs,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
**stride_kwargs
)
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

View File

@ -190,28 +190,17 @@ class KVConnectorModelRunnerMixin:
return False
attn_backend = attn_group.backend
if (
hasattr(kv_cache_spec, "head_size_v")
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
1234,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
**kwargs,
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True,
**stride_kwargs,
)
except (AttributeError, NotImplementedError):
return False
@ -268,22 +257,12 @@ class KVConnectorModelRunnerMixin:
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
attn_backend = attn_group.backend
if (
hasattr(kv_cache_spec, "head_size_v")
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
**kwargs,
)
# prepend a num_layers dimension to the shape
@ -292,7 +271,6 @@ class KVConnectorModelRunnerMixin:
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True,
**stride_kwargs,
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):