[Core][Hybrid allocator + connector 2/n] Unify remove_skipped_blocks by get_last_useful_token (#25431)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
parent 0b8e871e5e
commit efe73e9b57
@@ -243,18 +243,53 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError
 
-    @abstractmethod
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
-        Remove the blocks that are no longer needed from `blocks` and free the
-        blocks. The removed blocks should be replaced by null_block.
-        Need to be customized for each attention type.
+        Remove and free the blocks that are no longer needed for attention computation.
+        The removed blocks should be replaced by null_block.
+
+        This function depends on `get_num_skipped_tokens`, which needs to be
+        implemented differently for each attention type.
 
         Args:
             request_id: The request ID.
             num_computed_tokens: The number of tokens that have been computed.
         """
-        raise NotImplementedError
+        # Remove the blocks that will be skipped during attention computation.
+        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
+        if num_skipped_tokens <= 0:
+            # This indicates that ALL tokens are inside the attention window,
+            # so we do not need to free any blocks outside it. A typical case
+            # is full attention, where we never free any token before the
+            # request is finished.
+            return
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        blocks = self.req_to_blocks[request_id]
+        removed_blocks: list[KVCacheBlock] = []
+        # Because blocks are indexed from 0, the num_skipped_blocks-th block
+        # corresponds to index num_skipped_blocks - 1.
+        for i in range(num_skipped_blocks - 1, -1, -1):
+            if blocks[i] == self._null_block:
+                # If the block is already a null block, the blocks before it
+                # should also have been set to null blocks by previous calls
+                # to this function.
+                break
+            removed_blocks.append(blocks[i])
+            blocks[i] = self._null_block
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        # The default behavior is to not skip any tokens.
+        return 0
 
 
 class FullAttentionManager(SingleTypeKVCacheManager):
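The net effect of this hunk is a template-method refactor: the block-freeing loop that each attention type used to duplicate now lives once in the base class, and subclasses only override get_num_skipped_tokens. Below is a minimal, self-contained sketch of that loop (BLOCK_SIZE, NULL, and the string blocks are toy stand-ins for vLLM's block size, null block, and KVCacheBlock objects, not the real types) showing why iterating backwards and stopping at the first null block is sufficient:

    # Toy re-enactment of the unified remove_skipped_blocks loop.
    BLOCK_SIZE = 4
    NULL = object()  # plays the role of block_pool.null_block

    def remove_skipped_blocks(blocks: list, num_skipped_tokens: int) -> list:
        if num_skipped_tokens <= 0:
            return []  # every token is still inside the attention window
        num_skipped_blocks = num_skipped_tokens // BLOCK_SIZE
        freed = []
        # Walk backwards and stop at the first null block: everything before
        # it was already nulled out by an earlier call.
        for i in range(num_skipped_blocks - 1, -1, -1):
            if blocks[i] is NULL:
                break
            freed.append(blocks[i])
            blocks[i] = NULL
        return freed

    blocks = [f"blk{i}" for i in range(6)]
    assert remove_skipped_blocks(blocks, 9) == ["blk1", "blk0"]  # 9 // 4 = 2 blocks
    assert remove_skipped_blocks(blocks, 13) == ["blk2"]         # stops at the null blocks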
@@ -298,10 +333,6 @@ class FullAttentionManager(SingleTypeKVCacheManager):
                 computed.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # No need to remove blocks for full attention.
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         blocks = self.req_to_blocks[running_request_id]
        num_common_blocks = 0
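This deletion is behavior-preserving: FullAttentionManager now inherits the base remove_skipped_blocks, and since the default get_num_skipped_tokens returns 0, the inherited method returns early and never frees blocks, exactly like the removed pass. The same reasoning covers the MambaManager and CrossAttentionManager overrides deleted in the later hunks.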
@@ -389,28 +420,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
                 computed.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer in the sliding window and
-        # will be skipped during the attention computation.
-        last_useful_token = num_computed_tokens - self.sliding_window + 1
-        last_useful_block = last_useful_token // self.block_size
-        if last_useful_block <= 0:
-            # Early return if there are not enough tokens to fill the sliding window.
-            return
-        blocks = self.req_to_blocks[request_id]
-        if blocks[last_useful_block - 1] == self._null_block:
-            # Early return if there are no blocks to remove.
-            return
-        removed_blocks: list[KVCacheBlock] = []
-        for i in range(last_useful_block - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        For sliding window attention, this corresponds to the tokens that are
+        prior to the current sliding window.
+
+        Example:
+            sliding_window=4, num_computed_tokens=7
+
+            Tokens: [ 0 1 2 3 4 5 6 7 ]
+                      |-computed--|
+                                    ^ next token to be computed
+                              |-----| sliding window for next token
+                      |-----| skipped
+
+            The sliding window for the next token contains tokens 4~7. Tokens
+            0~3 will be skipped for attention computation since they are
+            outside the sliding window. Thus, get_num_skipped_tokens(7) == 4.
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        return num_computed_tokens - self.sliding_window + 1
 
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
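As a quick sanity check on the formula, a standalone sketch (a plain function with an explicit sliding_window argument, not the vLLM method) that reproduces the docstring example:

    # Sliding window: tokens before the next token's window are skipped.
    def get_num_skipped_tokens(num_computed_tokens: int, sliding_window: int) -> int:
        return num_computed_tokens - sliding_window + 1

    # The docstring example: window 4, 7 computed tokens -> tokens 0~3 skipped.
    assert get_num_skipped_tokens(7, 4) == 4
    # Before the window fills up, the result is <= 0, which the base-class
    # remove_skipped_blocks treats as "free nothing".
    assert get_num_skipped_tokens(2, 4) <= 0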
@@ -511,40 +547,51 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
                 break
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Remove the blocks that are no longer in the chunked attention
-        # window and will be skipped during the attention computation.
-        # [chunk 0][chunk 1]local_attention_start_idx ... current
-        # We compute the number of previous chunks to get the index of the
-        # current chunk window's starting offset.
-        # E.g., for 1024 computed tokens, the 1024th token (0-indexed) is in
-        # the second chunk, so there is 1 previous chunk and the start idx
-        # is 1024; for 1023, it is 0.
-        num_cached_block = self.num_cached_block.get(request_id, 0)
-        local_attention_start_idx = (
-            num_computed_tokens
-            // self.attention_chunk_size
-            * self.attention_chunk_size
-        )
-        first_useful_block_idx = local_attention_start_idx // self.block_size
-        if num_cached_block > 0:
-            # Make sure we don't delete the last cached block.
-            first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1)
-        # If block size = 128: 0 -> block 0, 1024 (= 128 * 8) -> block 8,
-        # 372 (= 128 * 2 + 116) -> block 2.
-        blocks = self.req_to_blocks[request_id]
-        removed_blocks: list[KVCacheBlock] = []
-        # We need to keep the last block to get the previous hash key.
-        for i in range(first_useful_block_idx - 1, -1, -1):
-            if blocks[i] == self._null_block:
-                # If the block is already a null block, the blocks before it
-                # should also have been set to null blocks by previous calls
-                # to this function.
-                break
-            removed_blocks.append(blocks[i])
-            blocks[i] = self._null_block
-        self.block_pool.free_blocks(removed_blocks)
+    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
+        """
+        Get the number of tokens that will be skipped for attention computation.
+
+        For chunked local attention, this corresponds to the tokens that are on
+        the left side of the current chunk.
+
+        Example 1:
+            chunk size = 8, num_computed_tokens = 13
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |--------- computed ---------|
+                                                     ^ next token to be computed
+                                        |-------------| attention window for next token
+                      |---skipped---|
+            Output: get_num_skipped_tokens(13) == 8
+
+        Example 2:
+            chunk size = 8, num_computed_tokens = 8
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |--computed---|
+                                        ^ next token to be computed
+                                       |--| attention window for next token
+                      |---skipped---|
+            Output: get_num_skipped_tokens(8) == 8
+
+        Example 3:
+            chunk size = 8, num_computed_tokens = 7
+            Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
+                      |-computed--|
+                                    ^ next token to be computed
+                      |-------------| attention window for next token
+            No token should be skipped.
+            Output: get_num_skipped_tokens(7) == 0
+
+        Args:
+            num_computed_tokens: The number of tokens that have been computed.
+
+        Returns:
+            The number of tokens that will be skipped for attention computation.
+        """
+        num_skipped_tokens = (
+            num_computed_tokens // self.attention_chunk_size
+        ) * self.attention_chunk_size
+        return num_skipped_tokens
 
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
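The chunked-local-attention counterpart, again as a standalone sketch (plain function with an explicit chunk_size argument, not the vLLM method) checking the three docstring examples:

    # Chunked local attention: round down to the start of the chunk that
    # contains the next token; everything left of that chunk is skipped.
    def get_num_skipped_tokens(num_computed_tokens: int, chunk_size: int) -> int:
        return (num_computed_tokens // chunk_size) * chunk_size

    assert get_num_skipped_tokens(13, 8) == 8  # next token 13 attends within [8..15]
    assert get_num_skipped_tokens(8, 8) == 8   # next token 8 starts a new chunk
    assert get_num_skipped_tokens(7, 8) == 0   # next token 7 is still in [0..7]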
@@ -590,12 +637,6 @@ class MambaManager(SingleTypeKVCacheManager):
 
         return computed_blocks
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Here unused blocks may be freed up for running requests.
-        # TODO(@s3woz): Free up all blocks that aren't needed by Mamba2
-        # (for which find_longest_cache_hit returns block_pool.null_block).
-        pass
-
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         Cascade attention is not supported by Mamba.
@@ -676,11 +717,6 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         # Return empty blocks to indicate no cache hits
         raise NotImplementedError("CrossAttentionManager does not support caching")
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
-        # Cross-attention blocks represent encoder states, which are needed
-        # for the entire decoding process, so no blocks should be skipped.
-        pass
-
 
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,