[Core][Hybrid allocator + connector 2/n] Unify remove_skipped_blocks by get_last_useful_token (#25431)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Kuntai Du 2025-11-05 16:12:00 -08:00 committed by GitHub
parent 0b8e871e5e
commit efe73e9b57


@@ -243,18 +243,53 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError

-    @abstractmethod
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
-        Remove the blocks that are no longer needed from `blocks` and free the
-        blocks. The removed blocks should be replaced by null_block.
-        Need to be customized for each attention type.
+        Remove and free the blocks that are no longer needed for attention
+        computation. The removed blocks should be replaced by null_block.
+        This function depends on `get_num_skipped_tokens`, which needs to be
+        implemented differently for each attention type.

         Args:
             request_id: The request ID.
             num_computed_tokens: The number of tokens that have been computed.
         """
-        raise NotImplementedError
+        # Remove the blocks that will be skipped during attention computation.
+        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
+        if num_skipped_tokens <= 0:
+            # This indicates that ALL tokens are inside the attention window,
+            # so we do not need to free any blocks outside it. A typical case
+            # is full attention, where we never free any token before the
+            # request is finished.
+            return
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        blocks = self.req_to_blocks[request_id]
+        removed_blocks: list[KVCacheBlock] = []
+        # Because blocks are indexed from 0, the num_skipped_blocks-th block
+        # corresponds to index num_skipped_blocks - 1.
+        for i in range(num_skipped_blocks - 1, -1, -1):
+            if blocks[i] == self._null_block:
+                # If the block is already a null block, the blocks before it
+                # should also have been set to null blocks by previous calls
+                # to this function.
+                break
+            removed_blocks.append(blocks[i])
+            blocks[i] = self._null_block
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        # The default behavior is to not skip any tokens.
+        return 0


 class FullAttentionManager(SingleTypeKVCacheManager):
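
For reference, here is a minimal, self-contained sketch of the unified mechanism above. ToyManager and ToySlidingWindow are illustrative stand-ins, not the vLLM classes; the block pool, request map, and KVCacheBlock type are replaced by plain Python objects.

    # Toy model of the unified remove_skipped_blocks / get_num_skipped_tokens
    # split. NULL plays the role of the shared null_block sentinel.
    NULL = object()

    class ToyManager:
        def __init__(self, block_size: int, blocks: list):
            self.block_size = block_size
            self.blocks = blocks
            self.freed = []  # stand-in for block_pool.free_blocks

        def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
            # Default: skip nothing (full attention keeps every token).
            return 0

        def remove_skipped_blocks(self, num_computed_tokens: int) -> None:
            num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
            if num_skipped_tokens <= 0:
                return
            num_skipped_blocks = num_skipped_tokens // self.block_size
            # Walk backwards from the last skipped block and stop at the
            # first null block: everything before it was nulled earlier.
            for i in range(num_skipped_blocks - 1, -1, -1):
                if self.blocks[i] is NULL:
                    break
                self.freed.append(self.blocks[i])
                self.blocks[i] = NULL

    class ToySlidingWindow(ToyManager):
        def __init__(self, block_size: int, blocks: list, sliding_window: int):
            super().__init__(block_size, blocks)
            self.sliding_window = sliding_window

        def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
            return num_computed_tokens - self.sliding_window + 1

    m = ToySlidingWindow(block_size=2, blocks=["b0", "b1", "b2", "b3"], sliding_window=4)
    m.remove_skipped_blocks(num_computed_tokens=7)  # 4 skipped tokens -> 2 blocks
    assert m.freed == ["b1", "b0"] and m.blocks[2:] == ["b2", "b3"]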
@@ -298,10 +333,6 @@ class FullAttentionManager(SingleTypeKVCacheManager):
             computed.pop()
         return computed_blocks

-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # No need to remove blocks for full attention.
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         blocks = self.req_to_blocks[running_request_id]
         num_common_blocks = 0
@@ -389,28 +420,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             computed.pop()
         return computed_blocks

-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer in the sliding window and are
-        # skipped during the attention computation.
-        last_useful_token = num_computed_tokens - self.sliding_window + 1
-        last_useful_block = last_useful_token // self.block_size
-        if last_useful_block <= 0:
-            # Early return if there are not enough tokens to fill the sliding window
-            return
-        blocks = self.req_to_blocks[request_id]
-        if blocks[last_useful_block - 1] == self._null_block:
-            # Early return if there are no blocks to remove
-            return
-        removed_blocks: list[KVCacheBlock] = []
-        for i in range(last_useful_block - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention
+        computation. For sliding window, this corresponds to the tokens
+        prior to the current sliding window.
+
+        Example:
+            sliding_window=4, num_computed_tokens=7
+            Tokens: [ 0 1 2 3 4 5 6 7 ]
+                    |--computed---|
+                                    ^ next token to be computed
+                            |-------| sliding window for next token
+                    |-------| skipped
+            The current window contains tokens 4~7. Tokens 0~3 are skipped
+            for attention computation since they are outside the sliding
+            window. Thus, get_num_skipped_tokens(7) == 4.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        return num_computed_tokens - self.sliding_window + 1

     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
@@ -511,40 +547,51 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
                 break
         return computed_blocks

-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer in the chunked attention
-        # window and are skipped during the attention computation.
-        # [chunk 0][chunk 1] local_attention_start_idx ... current
-        # We count the number of previous chunks to get the starting offset
-        # of the current chunk window, e.g. with 1024 computed tokens the
-        # 1024th token (0-indexed) is in the second chunk, so there is one
-        # previous chunk and the start idx is 1024; for 1023 it is 0.
-        num_cached_block = self.num_cached_block.get(request_id, 0)
-        local_attention_start_idx = (
-            num_computed_tokens // self.attention_chunk_size * self.attention_chunk_size
-        )
-        first_useful_block_idx = local_attention_start_idx // self.block_size
-        if num_cached_block > 0:
-            # Make sure we don't delete the last cached block.
-            first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1)
-        # If block size = 128: 0 -> block 0, 1024 (= 128 * 8) -> block 8,
-        # 372 (= 128 * 2 + 116) -> block 2.
-        blocks = self.req_to_blocks[request_id]
-        removed_blocks: list[KVCacheBlock] = []
-        # We need to keep the last block to get the previous hash key.
-        for i in range(first_useful_block_idx - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention
+        computation. For chunked local attention, this corresponds to the
+        tokens to the left of the current chunk.
+
+        Example 1:
+            chunk size = 8, num_computed_tokens = 13
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                    |-----------computed-----------|
+                                                     ^ next token to be computed
+                                      |---------------| attention window for next token
+                    |----skipped----|
+            Output: get_num_skipped_tokens(13) == 8
+
+        Example 2:
+            chunk size = 8, num_computed_tokens = 8
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                    |---computed----|
+                                        ^ next token to be computed
+                                      |-| attention window for next token
+                    |----skipped----|
+            Output: get_num_skipped_tokens(8) == 8
+
+        Example 3:
+            chunk size = 8, num_computed_tokens = 7
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                    |--computed---|
+                                    ^ next token to be computed
+                    |---------------| attention window for next token
+            No token should be skipped.
+            Output: get_num_skipped_tokens(7) == 0
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        num_skipped_tokens = (
+            num_computed_tokens // self.attention_chunk_size
+        ) * self.attention_chunk_size
+        return num_skipped_tokens

     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
@@ -590,12 +637,6 @@ class MambaManager(SingleTypeKVCacheManager):
         return computed_blocks

-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Here unused blocks may be freed up for running requests.
-        # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
-        # (for which find_longest_cache_hit returns block_pool.null_block)
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         Cascade attention is not supported by mamba.
@@ -676,11 +717,6 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         # Return empty blocks to indicate no cache hits
         raise NotImplementedError("CrossAttentionManager does not support caching")

-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Cross-attention blocks represent encoder states which are needed
-        # for the entire decoding process, so no blocks should be skipped
-        pass
-

 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,