mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 19:27:08 +08:00
Exteng SinkFullAttentionManager to handle sink blocks management, avoid modifying blk_table_tensor during the build of attn_metadata
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
a7430ab479
commit
93a7afcab3
@ -66,20 +66,13 @@ def create_static_sink_attention_backend(
|
|||||||
common_attn_metadata.seq_lens[:] = (
|
common_attn_metadata.seq_lens[:] = (
|
||||||
common_attn_metadata.seq_lens + self.sink_len
|
common_attn_metadata.seq_lens + self.sink_len
|
||||||
)
|
)
|
||||||
|
common_attn_metadata.seq_lens[
|
||||||
|
common_attn_metadata.seq_lens == self.sink_len
|
||||||
|
] = 0
|
||||||
common_attn_metadata.max_seq_len = (
|
common_attn_metadata.max_seq_len = (
|
||||||
common_attn_metadata.max_seq_len + self.sink_len
|
common_attn_metadata.max_seq_len + self.sink_len
|
||||||
)
|
)
|
||||||
|
|
||||||
blk_table_tensor = common_attn_metadata.block_table_tensor
|
|
||||||
sink_block_table = self.sink_block_table[None, :].expand(
|
|
||||||
blk_table_tensor.shape[0], -1
|
|
||||||
)
|
|
||||||
blk_table_tensor_clone = blk_table_tensor.clone()
|
|
||||||
blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[
|
|
||||||
:, : -self.num_sink_blocks
|
|
||||||
]
|
|
||||||
blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table
|
|
||||||
|
|
||||||
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||||
|
|
||||||
attn_backend = subclass_attention_backend(
|
attn_backend = subclass_attention_backend(
|
||||||
|
|||||||
@ -801,6 +801,146 @@ class SinkFullAttentionManager(FullAttentionManager):
|
|||||||
num_sink_block = sink_len // self.block_size
|
num_sink_block = sink_len // self.block_size
|
||||||
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
|
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
|
||||||
|
|
||||||
|
def get_num_blocks_to_allocate(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
num_tokens: int,
|
||||||
|
new_computed_blocks: Sequence[KVCacheBlock],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get the number of blocks needed to be allocated for the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The request ID.
|
||||||
|
num_tokens: The total number of tokens that need a slot (including
|
||||||
|
tokens that are already allocated).
|
||||||
|
new_computed_blocks: The new computed blocks just hitting the
|
||||||
|
prefix caching.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of blocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_required_blocks = cdiv(num_tokens, self.block_size)
|
||||||
|
num_new_blocks = (
|
||||||
|
num_required_blocks
|
||||||
|
- len(new_computed_blocks)
|
||||||
|
- len(self.req_to_blocks[request_id])
|
||||||
|
)
|
||||||
|
# Number of sink blocks is calculated into num_new_blocks
|
||||||
|
if len(self.req_to_blocks[request_id]) > 0:
|
||||||
|
num_new_blocks = num_new_blocks + len(self.sink_blocks)
|
||||||
|
# If a computed block of a request is an eviction candidate (in the
|
||||||
|
# free queue and ref_cnt == 0), it will be changed from a free block
|
||||||
|
# to a computed block when the request is allocated, so we also count
|
||||||
|
# it as needed to be allocated.
|
||||||
|
num_evictable_computed_blocks = sum(
|
||||||
|
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
|
||||||
|
)
|
||||||
|
return num_new_blocks + num_evictable_computed_blocks
|
||||||
|
|
||||||
|
def save_new_computed_blocks(
|
||||||
|
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add the new computed blocks to the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The request ID.
|
||||||
|
new_computed_blocks: The new computed blocks just hitting the
|
||||||
|
prefix cache.
|
||||||
|
"""
|
||||||
|
if request_id not in self.num_cached_block:
|
||||||
|
# A new request.
|
||||||
|
req_blocks = self.req_to_blocks[request_id]
|
||||||
|
assert len(req_blocks) == 0
|
||||||
|
# Append both sink blocks and hitted prefix cache blocks
|
||||||
|
req_blocks.extend(self.sink_blocks + new_computed_blocks)
|
||||||
|
self.num_cached_block[request_id] = len(new_computed_blocks)
|
||||||
|
else:
|
||||||
|
# A running request. Should not have new computed blocks.
|
||||||
|
assert len(new_computed_blocks) == 0
|
||||||
|
|
||||||
|
def allocate_new_blocks(
|
||||||
|
self, request_id: str, num_tokens: int
|
||||||
|
) -> list[KVCacheBlock]:
|
||||||
|
"""
|
||||||
|
Allocate new blocks for the request to give it at least `num_tokens`
|
||||||
|
token slots.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The request ID.
|
||||||
|
num_tokens: The total number of tokens that need a slot (including
|
||||||
|
tokens that are already allocated).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new allocated blocks.
|
||||||
|
"""
|
||||||
|
req_blocks = self.req_to_blocks[request_id]
|
||||||
|
num_required_blocks = cdiv(num_tokens, self.block_size)
|
||||||
|
num_new_blocks = num_required_blocks - len(req_blocks)
|
||||||
|
# For existing requests, number of sink blocks is calculated into
|
||||||
|
# num_new_blocks
|
||||||
|
if len(req_blocks) > 0:
|
||||||
|
num_new_blocks = num_new_blocks + len(self.sink_blocks)
|
||||||
|
if num_new_blocks <= 0:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||||
|
# For new requests, allocate sink blocks
|
||||||
|
if len(req_blocks) == 0:
|
||||||
|
req_blocks.extend(self.sink_blocks + new_blocks)
|
||||||
|
else:
|
||||||
|
req_blocks.extend(new_blocks)
|
||||||
|
return new_blocks
|
||||||
|
|
||||||
|
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||||
|
"""
|
||||||
|
Cache the blocks for the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The request.
|
||||||
|
num_tokens: The total number of tokens that need to be cached
|
||||||
|
(including tokens that are already cached).
|
||||||
|
"""
|
||||||
|
num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
|
||||||
|
num_full_blocks = num_tokens // self.block_size
|
||||||
|
|
||||||
|
if num_cached_blocks >= num_full_blocks:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.block_pool.cache_full_blocks(
|
||||||
|
request=request,
|
||||||
|
# Do not cache sink blocks
|
||||||
|
blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :],
|
||||||
|
num_cached_blocks=num_cached_blocks,
|
||||||
|
num_full_blocks=num_full_blocks,
|
||||||
|
block_size=self.block_size,
|
||||||
|
kv_cache_group_id=self.kv_cache_group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_cached_block[request.request_id] = num_full_blocks
|
||||||
|
|
||||||
|
def free(self, request_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Free the blocks for the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The request ID.
|
||||||
|
"""
|
||||||
|
# Default to [] in case a request is freed (aborted) before alloc.
|
||||||
|
req_blocks = self.req_to_blocks.pop(request_id, [])
|
||||||
|
# Do not free sink blocks
|
||||||
|
if len(req_blocks) > 0:
|
||||||
|
req_blocks = req_blocks[len(self.sink_blocks) :]
|
||||||
|
|
||||||
|
# Free blocks in reverse order so that the tail blocks are
|
||||||
|
# freed first.
|
||||||
|
ordered_blocks = reversed(req_blocks)
|
||||||
|
|
||||||
|
self.block_pool.free_blocks(ordered_blocks)
|
||||||
|
self.num_cached_block.pop(request_id, None)
|
||||||
|
|
||||||
|
|
||||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||||
FullAttentionSpec: FullAttentionManager,
|
FullAttentionSpec: FullAttentionManager,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class BlockTable:
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
kernel_block_size: int,
|
kernel_block_size: int,
|
||||||
cp_kv_cache_interleave_size: int,
|
cp_kv_cache_interleave_size: int,
|
||||||
|
sink_len: int = 0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -63,6 +64,8 @@ class BlockTable:
|
|||||||
self.use_hybrid_blocks = True
|
self.use_hybrid_blocks = True
|
||||||
|
|
||||||
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
|
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
|
||||||
|
self.sink_block_len = sink_len // self.block_size
|
||||||
|
self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len
|
||||||
|
|
||||||
self.block_table = self._make_buffer(
|
self.block_table = self._make_buffer(
|
||||||
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
|
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
|
||||||
@ -151,7 +154,7 @@ class BlockTable:
|
|||||||
block_table_indices = (
|
block_table_indices = (
|
||||||
req_indices * self.max_num_blocks_per_req
|
req_indices * self.max_num_blocks_per_req
|
||||||
+ positions // virtual_block_size
|
+ positions // virtual_block_size
|
||||||
)
|
) + self.sink_block_len
|
||||||
|
|
||||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||||
# Use virtual_block_size for mask calculation, which marks local
|
# Use virtual_block_size for mask calculation, which marks local
|
||||||
@ -177,9 +180,10 @@ class BlockTable:
|
|||||||
mask, slot_mapping, -1
|
mask, slot_mapping, -1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# When self.sink_block_len > 0, we need to shift block table indices
|
||||||
block_table_indices = (
|
block_table_indices = (
|
||||||
req_indices * self.max_num_blocks_per_req + positions // self.block_size
|
req_indices * self.max_num_blocks_per_req + positions // self.block_size
|
||||||
)
|
) + self.sink_block_len
|
||||||
|
|
||||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||||
block_offsets = positions % self.block_size
|
block_offsets = positions % self.block_size
|
||||||
@ -293,7 +297,7 @@ class MultiGroupBlockTable:
|
|||||||
block_size,
|
block_size,
|
||||||
max_num_reqs,
|
max_num_reqs,
|
||||||
max(
|
max(
|
||||||
cdiv(max_model_len + sink_len, block_size * total_cp_world_size),
|
cdiv(max_model_len, block_size * total_cp_world_size),
|
||||||
1 + num_speculative_tokens,
|
1 + num_speculative_tokens,
|
||||||
),
|
),
|
||||||
max_num_batched_tokens,
|
max_num_batched_tokens,
|
||||||
@ -301,6 +305,7 @@ class MultiGroupBlockTable:
|
|||||||
device,
|
device,
|
||||||
kernel_block_size,
|
kernel_block_size,
|
||||||
cp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size,
|
||||||
|
sink_len=sink_len,
|
||||||
)
|
)
|
||||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -101,7 +101,10 @@ def _reshape_kv_cache(
|
|||||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||||
|
|
||||||
attn_backend = attn_backends[layer_name]
|
attn_backend = attn_backends[layer_name]
|
||||||
if hasattr(kv_cache_spec, "head_size_v"):
|
if (
|
||||||
|
getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
|
||||||
|
!= kv_cache_spec.head_size
|
||||||
|
):
|
||||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
stride_kwargs = {"diff_kv": True}
|
stride_kwargs = {"diff_kv": True}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -5206,7 +5206,10 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||||
|
|
||||||
if hasattr(kv_cache_spec, "head_size_v"):
|
if (
|
||||||
|
getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
|
||||||
|
!= kv_cache_spec.head_size
|
||||||
|
):
|
||||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
stride_kwargs = {"diff_kv": True}
|
stride_kwargs = {"diff_kv": True}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -190,7 +190,10 @@ class KVConnectorModelRunnerMixin:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
attn_backend = attn_group.backend
|
attn_backend = attn_group.backend
|
||||||
if hasattr(kv_cache_spec, "head_size_v"):
|
if (
|
||||||
|
getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
|
||||||
|
!= kv_cache_spec.head_size
|
||||||
|
):
|
||||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
stride_kwargs = {"diff_kv": True}
|
stride_kwargs = {"diff_kv": True}
|
||||||
else:
|
else:
|
||||||
@ -265,7 +268,10 @@ class KVConnectorModelRunnerMixin:
|
|||||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||||
|
|
||||||
attn_backend = attn_group.backend
|
attn_backend = attn_group.backend
|
||||||
if hasattr(kv_cache_spec, "head_size_v"):
|
if (
|
||||||
|
getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
|
||||||
|
!= kv_cache_spec.head_size
|
||||||
|
):
|
||||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
stride_kwargs = {"diff_kv": True}
|
stride_kwargs = {"diff_kv": True}
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user