mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:24:56 +08:00
[Mamba][KVCacheManager] Simplify kv cache manage logic for mamba + MTP (#25119)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
1cab2f9cad
commit
3d5f1c8640
@ -565,35 +565,14 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
def get_num_blocks_to_allocate(
        self, request_id: str, num_tokens: int,
        new_computed_blocks: list[KVCacheBlock]) -> int:
    """
    Get the number of blocks needed to be allocated for the request.

    When the Mamba spec reserves speculative blocks, `num_tokens` is
    padded by `num_speculative_blocks` full blocks and the actual block
    accounting is delegated to the parent implementation.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot (including
            tokens that are already allocated).
        new_computed_blocks: The new computed blocks just hitting the
            prefix caching.

    Returns:
        The number of blocks, as computed by the parent class for the
        (possibly padded) token count.
    """
    # Narrow the spec type so the Mamba-only attributes below are valid.
    assert isinstance(self.kv_cache_spec, MambaSpec)
    num_speculative_blocks = self.kv_cache_spec.num_speculative_blocks
    if num_speculative_blocks > 0:
        # Allocate extra `num_speculative_blocks` blocks for
        # speculative decoding (MTP/EAGLE) with linear attention.
        num_tokens += self.kv_cache_spec.block_size * num_speculative_blocks
    # NOTE(review): the parent is presumed to subtract already-held and
    # newly computed blocks and to count evictable computed blocks, as the
    # removed inline code did -- confirm against the base class.
    return super().get_num_blocks_to_allocate(request_id, num_tokens,
                                              new_computed_blocks)
|
||||
|
||||
def allocate_new_blocks(self, request_id: str,
|
||||
num_tokens: int) -> list[KVCacheBlock]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user