mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 03:02:15 +08:00
[Core][Hybrid allocator + connector 2/n] Unify remove_skipped_blocks by get_last_useful_token (#25431)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
parent
0b8e871e5e
commit
efe73e9b57
@ -243,18 +243,53 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
||||||
"""
|
"""
|
||||||
Remove the blocks that are no longer needed from `blocks` and free the
|
Remove and free the blocks that are no longer needed for attention computation.
|
||||||
blocks. The removed blocks should be replaced by null_block.
|
The removed blocks should be replaced by null_block.
|
||||||
Need to be customized for each attention type.
|
|
||||||
|
This function depends on `get_num_skipped_tokens`, which needs to be implemented
|
||||||
|
differently for each attention type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The request ID.
|
request_id: The request ID.
|
||||||
num_computed_tokens: The number of tokens that have been computed.
|
num_computed_tokens: The number of tokens that have been computed.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
# Remove the blocks that will be skipped during attention computation.
|
||||||
|
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
|
||||||
|
if num_skipped_tokens <= 0:
|
||||||
|
# This indicates that ALL tokens are inside the attention window.
|
||||||
|
# Thus we do not need to free any blocks outside the attention window.
|
||||||
|
# A typical case is full attention, where we never free any token
|
||||||
|
# before the request is finished.
|
||||||
|
return
|
||||||
|
num_skipped_blocks = num_skipped_tokens // self.block_size
|
||||||
|
blocks = self.req_to_blocks[request_id]
|
||||||
|
removed_blocks: list[KVCacheBlock] = []
|
||||||
|
# Because block indices start from 0, the num_skipped_blocks-th block
|
||||||
|
# corresponds to index num_skipped_blocks - 1.
|
||||||
|
for i in range(num_skipped_blocks - 1, -1, -1):
|
||||||
|
if blocks[i] == self._null_block:
|
||||||
|
# If the block is already a null block, the blocks before it
|
||||||
|
# should also have been set to null blocks by the previous calls
|
||||||
|
# to this function.
|
||||||
|
break
|
||||||
|
removed_blocks.append(blocks[i])
|
||||||
|
blocks[i] = self._null_block
|
||||||
|
self.block_pool.free_blocks(removed_blocks)
|
||||||
|
|
||||||
|
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
|
||||||
|
"""
|
||||||
|
Get the number of tokens that will be skipped for attention computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_computed_tokens: The number of tokens that have been computed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of tokens that will be skipped for attention computation.
|
||||||
|
"""
|
||||||
|
# The default behavior is to not skip any tokens.
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||||
@ -298,10 +333,6 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
computed.pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
|
||||||
# No need to remove blocks for full attention.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
||||||
blocks = self.req_to_blocks[running_request_id]
|
blocks = self.req_to_blocks[running_request_id]
|
||||||
num_common_blocks = 0
|
num_common_blocks = 0
|
||||||
@ -389,28 +420,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
computed.pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
|
||||||
# Remove the blocks that are no longer be in the sliding window and
|
"""
|
||||||
# skipped during the attention computation.
|
Get the number of tokens that will be skipped for attention computation.
|
||||||
last_useful_token = num_computed_tokens - self.sliding_window + 1
|
|
||||||
last_useful_block = last_useful_token // self.block_size
|
For sliding window, this corresponds to the tokens that are prior to
|
||||||
if last_useful_block <= 0:
|
the current sliding window.
|
||||||
# Early return if tokens are not enough to fill the sliding window
|
|
||||||
return
|
Example:
|
||||||
blocks = self.req_to_blocks[request_id]
|
sliding_window=4, num_computed_tokens=7
|
||||||
if blocks[last_useful_block - 1] == self._null_block:
|
|
||||||
# Early return if there are no blocks to remove
|
Tokens: [ 0 1 2 3 4 5 6 7 ]
|
||||||
return
|
| ---- computed -----|
|
||||||
removed_blocks: list[KVCacheBlock] = []
|
^ next token to be computed
|
||||||
for i in range(last_useful_block - 1, -1, -1):
|
|-----------| sliding window for next token
|
||||||
if blocks[i] == self._null_block:
|
|--skipped---|
|
||||||
# If the block is already a null block, the blocks before it
|
|
||||||
# should also have been set to null blocks by the previous calls
|
The current window contains tokens 4~7. Tokens 0~3 will be skipped for
|
||||||
# to this function.
|
attention computation since they are outside the sliding window.
|
||||||
break
|
Thus, get_num_skipped_tokens(7) == 4.
|
||||||
removed_blocks.append(blocks[i])
|
|
||||||
blocks[i] = self._null_block
|
Args:
|
||||||
self.block_pool.free_blocks(removed_blocks)
|
num_computed_tokens: The number of tokens that have been computed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of tokens that will be skipped for attention computation.
|
||||||
|
"""
|
||||||
|
return num_computed_tokens - self.sliding_window + 1
|
||||||
|
|
||||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
@ -511,40 +547,51 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
break
|
break
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
|
||||||
# Remove the blocks that are no longer be in the chunked attention
|
"""
|
||||||
# window and skipped during the attention computation.
|
Get the number of tokens that will be skipped for attention computation.
|
||||||
|
|
||||||
# [chunk 0][chunk 1]local_attention_start_idx ... current
|
For chunked local attention, this corresponds to the tokens that are on
|
||||||
# we computed previous number of chunks to get the idx of
|
the left side of the current chunk.
|
||||||
# current chunk window starting offset,
|
|
||||||
# e.g. for computed 1024 tokens, the 1024th token (0 indexed)
|
Example 1:
|
||||||
# is in the second chunk, there are 1 prev chunk, the start idx
|
chunk size = 8, num_computed_tokens = 13
|
||||||
# is 1024. for 1023, it will be 0.
|
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
|
||||||
num_cached_block = self.num_cached_block.get(request_id, 0)
|
| ----- computed ---------------|
|
||||||
local_attention_start_idx = (
|
^^ next token to be computed
|
||||||
(num_computed_tokens)
|
|----------------| <-- attention window for
|
||||||
// self.attention_chunk_size
|
next token
|
||||||
* self.attention_chunk_size
|
|--- skipped -----|
|
||||||
)
|
Output: get_num_skipped_tokens(13) == 8
|
||||||
first_useful_block_idx = local_attention_start_idx // self.block_size
|
|
||||||
if num_cached_block > 0:
|
Example 2:
|
||||||
# Make sure we don't delete the last cached block
|
chunk size = 8, num_computed_tokens = 8
|
||||||
first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1)
|
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
|
||||||
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
|
| --- computed ---|
|
||||||
# block 8, 372 (= 128 * 2 + 116) -> block 2
|
^ next token to be computed
|
||||||
blocks = self.req_to_blocks[request_id]
|
|--| <-- attention window for next token
|
||||||
removed_blocks: list[KVCacheBlock] = []
|
| --- skipped ----|
|
||||||
# we need to keep the last block to get the previous hash key
|
Output: get_num_skipped_tokens(8) == 8
|
||||||
for i in range(first_useful_block_idx - 1, -1, -1):
|
|
||||||
if blocks[i] == self._null_block:
|
Example 3:
|
||||||
# If the block is already a null block, the blocks before it
|
chunk size = 8, num_computed_tokens = 7
|
||||||
# should also have been set to null blocks by the previous calls
|
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
|
||||||
# to this function.
|
|---computed---|
|
||||||
break
|
^ next token to be computed
|
||||||
removed_blocks.append(blocks[i])
|
|-----------------| <-- attention window for next token
|
||||||
blocks[i] = self._null_block
|
no token should be skipped.
|
||||||
self.block_pool.free_blocks(removed_blocks)
|
Output: get_num_skipped_tokens(7) == 0
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_computed_tokens: The number of tokens that have been computed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of tokens that will be skipped for attention computation.
|
||||||
|
"""
|
||||||
|
num_skipped_tokens = (
|
||||||
|
num_computed_tokens // self.attention_chunk_size
|
||||||
|
) * self.attention_chunk_size
|
||||||
|
return num_skipped_tokens
|
||||||
|
|
||||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
@ -590,12 +637,6 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
|
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
|
||||||
# Here unused blocks may be freed up for running requests.
|
|
||||||
# TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
|
|
||||||
# (for which find_longest_cache_hit returns block_pool.null_block)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
cascade attention is not supported by mamba
|
cascade attention is not supported by mamba
|
||||||
@ -676,11 +717,6 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
|||||||
# Return empty blocks to indicate no cache hits
|
# Return empty blocks to indicate no cache hits
|
||||||
raise NotImplementedError("CrossAttentionManager does not support caching")
|
raise NotImplementedError("CrossAttentionManager does not support caching")
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
|
||||||
# Cross-attention blocks represent encoder states which are needed
|
|
||||||
# for the entire decoding process, so no blocks should be skipped
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||||
FullAttentionSpec: FullAttentionManager,
|
FullAttentionSpec: FullAttentionManager,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user