[Core][Hybrid allocator + connector 2/n] Unify remove_skipped_blocks by get_last_useful_token (#25431)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
Kuntai Du 2025-11-05 16:12:00 -08:00 committed by GitHub
parent 0b8e871e5e
commit efe73e9b57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -243,18 +243,53 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
""" """
Remove the blocks that are no longer needed from `blocks` and free the Remove and free the blocks that are no longer needed for attention computation.
blocks. The removed blocks should be replaced by null_block. The removed blocks should be replaced by null_block.
Need to be customized for each attention type.
This function depends on `get_num_skipped_tokens`, which need to be implemented
differently for each attention type.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed. num_computed_tokens: The number of tokens that have been computed.
""" """
raise NotImplementedError # Remove the blocks that will be skipped during attention computation.
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
if num_skipped_tokens <= 0:
# This indicates that ALL tokens are inside attention window.
# Thus we do not need to free any blocks outside attention window.
# A typical case is full attention that we never free any token
# before the request is finished.
return
num_skipped_blocks = num_skipped_tokens // self.block_size
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# Because the block starts from index 0, the num_skipped_block-th block
# corresponds to index num_skipped_blocks - 1.
for i in range(num_skipped_blocks - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.
Args:
num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
# The default behavior is to not skip any tokens.
return 0
class FullAttentionManager(SingleTypeKVCacheManager): class FullAttentionManager(SingleTypeKVCacheManager):
@ -298,10 +333,6 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed.pop() computed.pop()
return computed_blocks return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# No need to remove blocks for full attention.
pass
def get_num_common_prefix_blocks(self, running_request_id: str) -> int: def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
blocks = self.req_to_blocks[running_request_id] blocks = self.req_to_blocks[running_request_id]
num_common_blocks = 0 num_common_blocks = 0
@ -389,28 +420,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
computed.pop() computed.pop()
return computed_blocks return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
# Remove the blocks that are no longer be in the sliding window and """
# skipped during the attention computation. Get the number of tokens that will be skipped for attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size For sliding window, this corresponds to the tokens that are prior to
if last_useful_block <= 0: the current sliding window.
# Early return if tokens are not enough to fill the sliding window
return Example:
blocks = self.req_to_blocks[request_id] sliding_window=4, num_computed_tokens=7
if blocks[last_useful_block - 1] == self._null_block:
# Early return if there are no blocks to remove Tokens: [ 0 1 2 3 4 5 6 7 ]
return | ---- computed -----|
removed_blocks: list[KVCacheBlock] = [] ^ next token to be computed
for i in range(last_useful_block - 1, -1, -1): |-----------| sliding window for next token
if blocks[i] == self._null_block: |--skipped---|
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls The current window contains tokens 4~7. Tokens 0~3 will be skipped for
# to this function. attention computation since they are outside the sliding window.
break Thus, get_num_skipped_tokens(7) == 4.
removed_blocks.append(blocks[i])
blocks[i] = self._null_block Args:
self.block_pool.free_blocks(removed_blocks) num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
return num_computed_tokens - self.sliding_window + 1
def get_num_common_prefix_blocks(self, running_request_id: str) -> int: def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
""" """
@ -511,40 +547,51 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
break break
return computed_blocks return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
# Remove the blocks that are no longer be in the chunked attention """
# window and skipped during the attention computation. Get the number of tokens that will be skipped for attention computation.
# [chunk 0][chunk 1]local_attention_start_idx ... current For chunked local attention, this corresponds to the tokens that are on
# we computed previous number of chunks to get the idx of the left side of the current chunk.
# current chunk window starting offset,
# e.g. for computed 1024 tokens, the 1024th token (0 indexed) Example 1:
# is in the second chunk, there are 1 prev chunk, the start idx chunk size = 8, num_computed_tokens = 13
# is 1024. for 1023, it will be 0. Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
num_cached_block = self.num_cached_block.get(request_id, 0) | ----- computed ---------------|
local_attention_start_idx = ( ^^ next token to be computed
(num_computed_tokens) |----------------| <-- attention window for
// self.attention_chunk_size next token
* self.attention_chunk_size |--- skipped -----|
) Output: get_num_skipped_tokens(13) == 8
first_useful_block_idx = local_attention_start_idx // self.block_size
if num_cached_block > 0: Example 2:
# Make sure we don't delete the last cached block chunk size = 8, num_computed_tokens = 8
first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1) Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> | --- computed ---|
# block 8, 372 (= 128 * 2 + 116) -> block 2 ^ next token to be computed
blocks = self.req_to_blocks[request_id] |--| <-- attention window for next token
removed_blocks: list[KVCacheBlock] = [] | --- skipped ----|
# we need to keep the last block to get the previous hash key Output: get_num_skipped_tokens(8) == 8
for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block: Example 3:
# If the block is already a null block, the blocks before it chunk size = 8, num_computed_tokens = 7
# should also have been set to null blocks by the previous calls Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
# to this function. |---computed---|
break ^ next token to be computed
removed_blocks.append(blocks[i]) |-----------------| <-- attention window for next token
blocks[i] = self._null_block no token should be skipped.
self.block_pool.free_blocks(removed_blocks) Output: get_num_skipped_tokens(7) == 0
Args:
num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
num_skipped_tokens = (
num_computed_tokens // self.attention_chunk_size
) * self.attention_chunk_size
return num_skipped_tokens
def get_num_common_prefix_blocks(self, running_request_id: str) -> int: def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
""" """
@ -590,12 +637,6 @@ class MambaManager(SingleTypeKVCacheManager):
return computed_blocks return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Here unused blocks may be freed up for running requests.
# TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
# (for which find_longest_cache_hit returns block_pool.null_block)
pass
def get_num_common_prefix_blocks(self, running_request_id: str) -> int: def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
""" """
cascade attention is not supported by mamba cascade attention is not supported by mamba
@ -676,11 +717,6 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
# Return empty blocks to indicate no cache hits # Return empty blocks to indicate no cache hits
raise NotImplementedError("CrossAttentionManager does not support caching") raise NotImplementedError("CrossAttentionManager does not support caching")
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
# Cross-attention blocks represent encoder states which are needed
# for the entire decoding process, so no blocks should be skipped
pass
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager, FullAttentionSpec: FullAttentionManager,