diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index df487ec2ccaa9..1cdc80dd3546c 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -539,7 +539,7 @@ def test_allocate_with_lookahead():
                                       max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
     )
     assert len(blocks.blocks) == 2  # ceil(5/4)=2 blocks
@@ -550,7 +550,7 @@ def test_allocate_with_lookahead():
     # required_blocks = ceil((3 + 2) /4) = 2
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=2,
     )
     assert len(blocks.blocks) == 2
@@ -561,7 +561,7 @@ def test_allocate_with_lookahead():
                                       max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=4,
     )
     assert len(blocks.blocks) == 2
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index ea4ec8a629e9a..a038106254666 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -299,7 +299,8 @@ def test_decode():
     req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4)
     assert new_blocks is not None and len(new_blocks.blocks) == 0
-    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-1].block_hash is None
 
     # Append slots with allocating a new block.
     req0.num_computed_tokens = 59
@@ -309,8 +310,10 @@ def test_decode():
     req0.append_output_token_ids(7)
     new_blocks = manager.allocate_slots(req0, 19)
     assert new_blocks is not None and len(new_blocks.blocks) == 1
-    assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
-    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-2].block_hash is not None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-1].block_hash is None
 
 
 def test_evict():
@@ -689,7 +692,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req0, 48, computed_blocks)
-    block_part0 = manager.req_to_blocks[req0.request_id]
+    block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
 
     # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
     req1 = make_request("1", common_token_ids * 2)
@@ -697,7 +700,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert computed_blocks.blocks == block_part0
     assert num_computed_tokens == 3 * 16
     manager.allocate_slots(req1, 48, computed_blocks)
-    block_part1 = manager.req_to_blocks[req1.request_id]
+    block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
     # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
     # | Req1-5(F)| ... |
     manager.free(req1)
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index bfe9df10d4d19..0ca2ced891485 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager(
     # Make sure the request stats are right.
     EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
     for req_id in req_ids:
-        blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
+        blocks = (scheduler.kv_cache_manager.single_type_manager.
+                  req_to_blocks[req_id])
         hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
-        assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
-                EXPECTED_TOTAL_BLOCKS)
+        assert (scheduler.kv_cache_manager.single_type_manager.
+                num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
         assert len(blocks) == EXPECTED_TOTAL_BLOCKS
         assert len(hashes) == EXPECTED_TOTAL_BLOCKS
 
@@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.encoder_cache_manager.cached) == 0
 
     # KVCache Manager.
-    assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
+    assert len(
+        scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
     assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
-    assert len(scheduler.kv_cache_manager.num_cached_block) == 0
+    assert len(
+        scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
     num_free_blocks = (
         scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
     assert num_free_blocks == (
diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py
index 595c8608fc64d..540720cb9b2ff 100644
--- a/tests/v1/core/test_specialized_manager.py
+++ b/tests/v1/core/test_specialized_manager.py
@@ -8,6 +8,14 @@
 from vllm.v1.core.specialized_manager import SlidingWindowManager
 from vllm.v1.kv_cache_interface import SlidingWindowSpec
 
 
+def get_sliding_window_manager(sliding_window_spec, block_pool):
+    return SlidingWindowManager(sliding_window_spec,
+                                block_pool,
+                                use_eagle=False,
+                                num_kv_cache_groups=1,
+                                caching_hash_fn=lambda x: x)
+
+
 def test_sliding_window_possible_cached_prefix():
     sliding_window_spec = SlidingWindowSpec(
         block_size=2,
@@ -19,9 +27,7 @@ def test_sliding_window_possible_cached_prefix():
     )
 
     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
-    manager = SlidingWindowManager(sliding_window_spec,
-                                   block_pool,
-                                   use_eagle=False)
+    manager = get_sliding_window_manager(sliding_window_spec, block_pool)
 
     def run_one_case(block_is_cached, expect_length):
         block_hash_list = [
@@ -81,9 +87,7 @@ def test_sliding_window_remove_skipped_blocks():
 
     block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
 
-    manager = SlidingWindowManager(sliding_window_spec,
-                                   block_pool,
-                                   use_eagle=False)
+    manager = get_sliding_window_manager(sliding_window_spec, block_pool)
 
     null_block_id = block_pool.null_block.block_id
@@ -104,39 +108,35 @@ def test_sliding_window_remove_skipped_blocks():
         1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
     ]
     block_table = id_to_block_table(original_block_ids)
-    removed = manager.remove_skipped_blocks(block_table, 0)
-    assert_block_id(removed, [])
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
     assert_block_id(block_table, original_block_ids)
 
     # 4 tokens are computed. Only token 0 is out of the sliding window. As
     # block 1000 also contains token 1 that is in the sliding window, block 1000
     # cannot be removed.
-    removed = manager.remove_skipped_blocks(block_table, 4)
-    assert_block_id(removed, [])
+    manager.remove_skipped_blocks("test", 4)
     assert_block_id(block_table, original_block_ids)
 
     # 5 tokens are computed. Token 0 & 1 are out of the sliding window.
     # Block 1000 can be removed.
-    removed = manager.remove_skipped_blocks(block_table, 5)
-    assert_block_id(removed, [original_block_ids[0]])
+    manager.remove_skipped_blocks("test", 5)
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
 
     # 6 tokens are computed. Token 0-2 are out of the sliding window.
     # Cannot remove new block as the block 1001 is still used by token 3.
-    removed = manager.remove_skipped_blocks(block_table, 6)
-    assert_block_id(removed, [])
+    manager.remove_skipped_blocks("test", 6)
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
 
     # 7 tokens are computed. Token 0-3 are out of the sliding window.
     # Block 1001 can be removed and block 1000 is already removed.
-    removed = manager.remove_skipped_blocks(block_table, 7)
-    assert_block_id(removed, [original_block_ids[1]])
+    manager.remove_skipped_blocks("test", 7)
     assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
 
     # 11 tokens are computed. Token 0-7 are out of the sliding window.
     # Block 1002 & 1003 can be removed now. Block 1003 represents a longer
     # sequence, and is expected to be evicted earlier than 1002, so the order
     # of removed blocks should be [1003, 1002].
-    removed = manager.remove_skipped_blocks(block_table, 11)
-    assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
+    manager.remove_skipped_blocks("test", 11)
     assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 8ef8143d1ed7b..c4ed127ece60e 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import defaultdict
-from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional
 
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
-from vllm.utils import cdiv, sha256
+from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
                                          hash_request_tokens)
-from vllm.v1.core.specialized_manager import get_specialized_manager
+from vllm.v1.core.specialized_manager import get_manager_for_kv_cache_spec
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -56,7 +55,6 @@ class KVCacheManager:
         self.block_size = kv_cache_spec.block_size
         self.num_gpu_blocks = kv_cache_config.num_blocks
         self.max_model_len = max_model_len
-        self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
 
         self.enable_caching = enable_caching
         self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
@@ -68,30 +66,20 @@ class KVCacheManager:
         self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
                                     enable_kv_cache_events)
 
-        self.specialized_manager = get_specialized_manager(
+        self.single_type_manager = get_manager_for_kv_cache_spec(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
             use_eagle=self.use_eagle,
+            num_kv_cache_groups=1,
+            caching_hash_fn=self.caching_hash_fn,
         )
 
-        # Mapping from request ID to blocks to track the blocks allocated
-        # for each request, so that we can free the blocks when the request
-        # is finished.
-        self.req_to_blocks: defaultdict[str,
-                                        list[KVCacheBlock]] = defaultdict(list)
-
         # Mapping from request ID to kv block hashes.
         # This is to avoid recomputing the block hashes for each call of
         # `get_computed_blocks` or `allocate_slots`.
         self.req_to_block_hashes: defaultdict[
             str, list[BlockHashType]] = defaultdict(list)
 
-        # {req_id: The number of cached blocks for this given request}
-        # This is used to track the number of cached blocks for each request.
-        # This is only used to track the RUNNING requests, we do not track the
-        # data for reempted ones.
-        self.num_cached_block: dict[str, int] = {}
-
     @property
     def usage(self) -> float:
         """Get the KV cache usage.
@@ -159,7 +147,7 @@ class KVCacheManager:
                 last_block_hash = None
 
             computed_blocks = (
-                self.specialized_manager.find_longest_cache_hit(block_hashes))
+                self.single_type_manager.find_longest_cache_hit(block_hashes))
 
             if self.log_stats:
                 assert self.prefix_cache_stats is not None
@@ -181,7 +169,7 @@ class KVCacheManager:
     def allocate_slots(
         self,
         request: Request,
-        num_tokens: int,
+        num_new_tokens: int,
         new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
     ) -> Optional[KVCacheBlocks]:
@@ -189,7 +177,7 @@ class KVCacheManager:
 
         Args:
             request: The request to allocate slots.
-            num_tokens: The number of tokens to allocate, including external
+            num_new_tokens: The number of tokens to allocate, including external
                 tokens. Note that this does not include tokens that have
                 already been computed locally (i.e. new_computed_blocks).
             new_computed_blocks: The new computed blocks just hitting the
@@ -215,44 +203,38 @@ class KVCacheManager:
         Returns:
             A list of new allocated blocks.
         """
-        if num_tokens == 0:
-            raise ValueError("num_tokens must be greater than 0")
+        if num_new_tokens == 0:
+            raise ValueError("num_new_tokens must be greater than 0")
 
         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
             new_computed_block_list = []
 
-        req_blocks = self.req_to_blocks[request.request_id]
-
         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
         # We can do this even if we cannot schedule this request due to
         # insufficient free blocks.
         # Should call this function before allocating new blocks to reduce
         # the number of evicted blocks.
-        removed_blocks = self.specialized_manager.remove_skipped_blocks(
-            req_blocks, request.num_computed_tokens)
-        self.block_pool.free_blocks(removed_blocks)
+        self.single_type_manager.remove_skipped_blocks(
+            request.request_id, request.num_computed_tokens)
 
         # The number of computed tokens is the number of computed tokens plus
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
                                len(new_computed_block_list) * self.block_size)
-        num_required_blocks = cdiv(
-            num_computed_tokens + num_tokens + num_lookahead_tokens,
-            self.block_size)
-        num_new_blocks = (num_required_blocks - len(req_blocks) -
-                          len(new_computed_block_list))
+        num_tokens_need_slot = min(
+            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
+            self.max_model_len)
+        num_blocks_to_allocate = (
+            self.single_type_manager.get_num_blocks_to_allocate(
+                request_id=request.request_id,
+                num_tokens=num_tokens_need_slot,
+                new_computed_blocks=new_computed_block_list,
+            ))
 
-        # If a computed block of a request is an eviction candidate (in the
-        # free queue and ref_cnt == 0), it cannot be counted as a free block
-        # when allocating this request.
-        num_evictable_computed_blocks = sum(1
-                                            for blk in new_computed_block_list
-                                            if blk.ref_cnt == 0)
-        if (num_new_blocks > self.block_pool.get_num_free_blocks() -
-                num_evictable_computed_blocks):
+        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
             # Cannot allocate new blocks
             return None
@@ -266,74 +248,33 @@ class KVCacheManager:
         # Append the new computed blocks to the request blocks until now to
         # avoid the case where the new blocks cannot be allocated.
-        req_blocks.extend(new_computed_block_list)
+        self.single_type_manager.save_new_computed_blocks(
+            request.request_id, new_computed_block_list)
 
-        # Start to handle new blocks
-
-        if num_new_blocks <= 0:
-            # No new block is needed.
-            new_blocks = []
-        else:
-            # Get new blocks from the free block pool.
-            num_new_blocks = min(
-                num_new_blocks,
-                self.block_pool.get_num_free_blocks(),
-                # Should not exceed the maximum number of blocks per request.
-                # This is especially because the block table has the shape
-                # [..., max_num_blocks_per_req].
-                self.max_num_blocks_per_req - len(req_blocks),
-            )
-            assert num_new_blocks > 0
-
-            # Concatenate the computed block IDs and the new block IDs.
-            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
-            req_blocks.extend(new_blocks)
+        new_blocks = self.single_type_manager.allocate_new_blocks(
+            request.request_id, num_tokens_need_slot)
 
         if not self.enable_caching:
             return KVCacheBlocks(new_blocks)
 
-        # Use `new_computed_block_list` for a new request, and
-        # `num_cached_block` for a running request.
-        num_cached_blocks = self.num_cached_block.get(
-            request.request_id, len(new_computed_block_list))
         # Speculated tokens might be rejected in the future, so we does
         # not cache any speculated tokens. We only cache blocks with
         # generated (accepted) tokens.
-        num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
-            request.spec_token_ids)) // self.block_size
+        self.single_type_manager.cache_blocks(
+            request, self.req_to_block_hashes[request.request_id],
+            num_computed_tokens + num_new_tokens - len(request.spec_token_ids))
 
-        self.block_pool.cache_full_blocks(
-            request=request,
-            blocks=req_blocks,
-            block_hashes=self.req_to_block_hashes[request.request_id],
-            num_cached_blocks=num_cached_blocks,
-            num_full_blocks=num_full_blocks_after_append,
-            block_size=self.block_size,
-            hash_fn=self.caching_hash_fn,
-        )
-
-        self.num_cached_block[
-            request.request_id] = num_full_blocks_after_append
         return KVCacheBlocks(new_blocks)
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
-        When caching is enabled, we free the blocks in reverse order so that
-        the tail blocks are evicted first.
+        We free the blocks in reverse order so that the tail blocks are evicted
+        first when caching is enabled.
 
         Args:
             request: The request to free the blocks.
         """
-        # Default to [] in case a request is freed (aborted) before alloc.
-        blocks = self.req_to_blocks.pop(request.request_id, [])
-        ordered_blocks: Iterable[KVCacheBlock] = blocks
-        if self.enable_caching:
-            # Free blocks in reverse order so that the tail blocks are
-            # freed first.
-            ordered_blocks = reversed(blocks)
-
-        self.block_pool.free_blocks(ordered_blocks)
-        self.num_cached_block.pop(request.request_id, None)
+        self.single_type_manager.free(request.request_id)
 
     def reset_prefix_cache(self) -> bool:
         """Reset prefix cache. This function may be used in RLHF
@@ -390,14 +331,8 @@ class KVCacheManager:
             int: The number of common prefix blocks.
         """
         assert request.status == RequestStatus.RUNNING
-        blocks = self.req_to_blocks[request.request_id]
-        num_common_blocks = 0
-        for block in blocks:
-            if block.ref_cnt == num_running_requests:
-                num_common_blocks += 1
-            else:
-                break
-        return num_common_blocks
+        return self.single_type_manager.get_num_common_prefix_blocks(
+            request.request_id, num_running_requests)
 
     def free_block_hashes(self, request: Request) -> None:
         """Discard the block hashes for the request.
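[Editor's note] The hunks above replace KVCacheManager's inline block accounting with a call into the per-type manager's `get_num_blocks_to_allocate()`. The arithmetic can be sketched as a minimal standalone function; the names below (`num_blocks_to_allocate`, `new_computed_ref_cnts`, `num_existing_blocks`) are illustrative, not part of the vLLM API:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the behavior of vllm.utils.cdiv.
    return -(a // -b)


def num_blocks_to_allocate(num_tokens: int, block_size: int,
                           num_existing_blocks: int,
                           new_computed_ref_cnts: list[int],
                           num_kv_cache_groups: int = 1) -> int:
    # Blocks required to cover every token slot of the request.
    num_required_blocks = cdiv(num_tokens, block_size)
    # New blocks beyond those the request already holds and those the
    # prefix-cache hit contributes.
    num_new_blocks = (num_required_blocks - len(new_computed_ref_cnts) -
                      num_existing_blocks)
    # A computed block with ref_cnt == 0 still sits in the free queue as an
    # eviction candidate; reusing it consumes one free block, so it counts
    # toward the allocation as well.
    num_evictable = sum(1 for ref_cnt in new_computed_ref_cnts
                        if ref_cnt == 0)
    return (num_new_blocks + num_evictable) * num_kv_cache_groups


# Mirrors test_allocate_with_lookahead: 3 new tokens + 2 lookahead tokens
# with block_size=4 and nothing allocated or cached yet -> ceil(5/4) = 2.
assert num_blocks_to_allocate(num_tokens=5, block_size=4,
                              num_existing_blocks=0,
                              new_computed_ref_cnts=[]) == 2
```

With evictable computed blocks folded into this count, the free-space check in `allocate_slots()` reduces to a single comparison against `block_pool.get_num_free_blocks()`.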
diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py
index f04eedf42662e..3fd3cb2841e07 100644
--- a/vllm/v1/core/specialized_manager.py
+++ b/vllm/v1/core/specialized_manager.py
@@ -1,17 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Callable
 
 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                         SlidingWindowSpec)
+from vllm.v1.request import Request
 
 
-class SpecializedManager(ABC):
+class SingleTypeKVCacheManager(ABC):
     """
-    An abstract base class for specialized managers that handle the kv
-    cache management logic of different attention layers.
+    An abstract base class for a manager that handles the kv cache management
+    logic of one specific type of attention layer.
     """
 
     def __init__(
@@ -19,12 +22,18 @@ class SpecializedManager(ABC):
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
         use_eagle: bool,
+        num_kv_cache_groups: int,
+        caching_hash_fn: Callable,
     ) -> None:
         """
-        Initializes the SpecializedManager.
+        Initializes the SingleTypeKVCacheManager.
         Args:
             kv_cache_spec: The kv_cache_spec for this manager.
             block_pool: The block pool.
+            use_eagle: Whether to use eagle.
+            num_kv_cache_groups: The number of kv cache groups managed by this
+                manager.
+            caching_hash_fn: The caching hash function.
         """
 
         self.block_size = kv_cache_spec.block_size
@@ -34,6 +43,149 @@ class SpecializedManager(ABC):
         # Needs special handling for find_longest_cache_hit if eagle is enabled
         self.use_eagle = use_eagle
 
+        # Mapping from request ID to blocks to track the blocks allocated
+        # for each request, so that we can free the blocks when the request
+        # is finished.
+        self.req_to_blocks: defaultdict[str,
+                                        list[KVCacheBlock]] = defaultdict(list)
+
+        # {req_id: The number of cached blocks for this given request}
+        # This is used to track the number of cached blocks for each request.
+        # This is only used to track the RUNNING requests; we do not track the
+        # data for preempted ones.
+        self.num_cached_block: dict[str, int] = {}
+
+        self.num_kv_cache_groups = num_kv_cache_groups
+        self.caching_hash_fn = caching_hash_fn
+
+    def get_num_blocks_to_allocate(
+            self, request_id: str, num_tokens: int,
+            new_computed_blocks: list[KVCacheBlock]) -> int:
+        """
+        Get the number of blocks that need to be allocated for the request.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix caching.
+
+        Returns:
+            The number of blocks.
+        """
+
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
+                          len(self.req_to_blocks[request_id]))
+        # If a computed block of a request is an eviction candidate (in the
+        # free queue and ref_cnt == 0), it will be changed from a free block
+        # to a computed block when the request is allocated, so we also count
+        # it as a block that needs to be allocated.
+        num_evictable_computed_blocks = sum(blk.ref_cnt == 0
+                                            for blk in new_computed_blocks)
+        return ((num_new_blocks + num_evictable_computed_blocks) *
+                self.num_kv_cache_groups)
+
+    def save_new_computed_blocks(
+            self, request_id: str,
+            new_computed_blocks: list[KVCacheBlock]) -> None:
+        """
+        Add the new computed blocks to the request.
+
+        Args:
+            request_id: The request ID.
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix cache.
+        """
+        if request_id not in self.num_cached_block:
+            # A new request.
+            req_blocks = self.req_to_blocks[request_id]
+            assert len(req_blocks) == 0
+            req_blocks.extend(new_computed_blocks)
+            self.num_cached_block[request_id] = len(new_computed_blocks)
+        else:
+            # A running request. Should not have new computed blocks.
+            assert len(new_computed_blocks) == 0
+
+    def allocate_new_blocks(self, request_id: str,
+                            num_tokens: int) -> list[KVCacheBlock]:
+        """
+        Allocate new blocks for the request to give it at least `num_tokens`
+        token slots.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+
+        Returns:
+            The new allocated blocks.
+        """
+        req_blocks = self.req_to_blocks[request_id]
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = num_required_blocks - len(req_blocks)
+        if num_new_blocks <= 0:
+            return []
+        else:
+            new_blocks = self.block_pool.get_new_blocks(
+                num_new_blocks * self.num_kv_cache_groups)
+            req_blocks.extend(new_blocks)
+            return new_blocks
+
+    def cache_blocks(self, request: Request, block_hashes: list[BlockHashType],
+                     num_tokens: int) -> None:
+        """
+        Cache the blocks for the request.
+
+        Args:
+            request: The request.
+            block_hashes: The block hashes of the request.
+            num_tokens: The total number of tokens that need to be cached
+                (including tokens that are already cached).
+        """
+        num_cached_blocks = self.num_cached_block[request.request_id]
+        num_full_blocks = num_tokens // self.block_size
+
+        self.block_pool.cache_full_blocks(
+            request=request,
+            blocks=self.req_to_blocks[request.request_id],
+            block_hashes=block_hashes,
+            num_cached_blocks=num_cached_blocks,
+            num_full_blocks=num_full_blocks,
+            block_size=self.block_size,
+            hash_fn=self.caching_hash_fn,
+        )
+
+        self.num_cached_block[request.request_id] = num_full_blocks
+
+    def free(self, request_id: str) -> None:
+        # Default to [] in case a request is freed (aborted) before alloc.
+        req_blocks = self.req_to_blocks.pop(request_id, [])
+
+        # Free blocks in reverse order so that the tail blocks are
+        # freed first.
+        ordered_blocks = reversed(req_blocks)
+
+        self.block_pool.free_blocks(ordered_blocks)
+        self.num_cached_block.pop(request_id, None)
+
+    @abstractmethod
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        Get the number of common prefix blocks for a request.
+
+        Args:
+            request_id: The request ID.
+            num_running_requests: The number of requests in the RUNNING state.
+
+        Returns:
+            The number of common prefix blocks.
+        """
+
+        raise NotImplementedError
+
     @abstractmethod
     def find_longest_cache_hit(
             self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
@@ -41,7 +193,8 @@ class SpecializedManager(ABC):
         Get the longest cache hit prefix of the blocks. If no cache hit is
         found, return an empty list. if eagle is enabled, drop the last matched
         block to force recompute the last block to get the required hidden
-        states for eagle drafting head.
+        states for eagle drafting head. Needs to be customized for each
+        attention type.
 
         Args:
             block_hashes: The block hashes of the request.
@@ -55,24 +208,23 @@ class SpecializedManager(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         """
-        Remove the blocks that are no longer needed from `blocks`. The removed
-        blocks should be replaced by null_block. Return the removed blocks
-        in eviction order, where the first returned block should be evicted
-        first.
-        Don't free the removed blocks in this function.
+        Remove the blocks that are no longer needed from the request's blocks
+        and replace them with null_block. The removed blocks are freed in
+        eviction order, where the first removed block is evicted first.
+        Needs to be customized for each attention type.
 
         Args:
-            blocks: The list of blocks to be updated.
+            request_id: The request ID.
             num_computed_tokens: The number of tokens that have been computed.
-
-        Returns:
-            The removed blocks in eviction order.
         """
         raise NotImplementedError
 
 
-class FullAttentionManager(SpecializedManager):
+class FullAttentionManager(SingleTypeKVCacheManager):
 
     def find_longest_cache_hit(
             self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
@@ -89,17 +241,28 @@ class FullAttentionManager(SpecializedManager):
                 computed_blocks.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         # No need to remove blocks for full attention.
-        return []
+        pass
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        blocks = self.req_to_blocks[request_id]
+        num_common_blocks = 0
+        for block in blocks:
+            if block.ref_cnt == num_running_requests:
+                num_common_blocks += 1
+            else:
+                break
+        return num_common_blocks
 
 
-class SlidingWindowManager(SpecializedManager):
+class SlidingWindowManager(SingleTypeKVCacheManager):
 
     def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
-                 use_eagle: bool):
-        super().__init__(kv_cache_spec, block_pool, use_eagle)
+                 use_eagle: bool, **kwargs) -> None:
+        super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
         self.sliding_window = kv_cache_spec.sliding_window
         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@@ -148,13 +311,13 @@ class SlidingWindowManager(SpecializedManager):
                 computed_blocks.pop()
         return computed_blocks
 
-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         # Remove the blocks that are no longer be in the sliding window and
         # skipped during the attention computation.
         last_useful_token = num_computed_tokens - self.sliding_window + 1
         last_useful_block = last_useful_token // self.block_size
-
+        blocks = self.req_to_blocks[request_id]
         removed_blocks: list[KVCacheBlock] = []
         for i in range(last_useful_block - 1, -1, -1):
             if blocks[i] == self._null_block:
@@ -164,17 +327,27 @@ class SlidingWindowManager(SpecializedManager):
                 break
             removed_blocks.append(blocks[i])
             blocks[i] = self._null_block
-        return removed_blocks
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        NOTE(Chen): The prefix blocks are null blocks for sliding window
+        layers. So it's not correct to count ref_cnt like
+        FullAttentionManager. Return 0 here for correctness.
+        Need to support cascade attention + sliding window in the future.
+        """
+        return 0
 
 
-spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
+spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
 }
 
 
-def get_specialized_manager(kv_cache_spec: KVCacheSpec,
-                            **kwargs) -> SpecializedManager:
+def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
+                                  **kwargs) -> SingleTypeKVCacheManager:
     manager_class = spec_manager_map[type(kv_cache_spec)]
     manager = manager_class(kv_cache_spec, **kwargs)
     return manager
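[Editor's note] For reference, the sliding-window eviction that `SlidingWindowManager.remove_skipped_blocks()` now performs in place can be sketched as below. This is a standalone toy, not the vLLM implementation: plain ints stand in for `KVCacheBlock` objects, and the removed IDs are returned instead of being handed to the block pool (the real method frees them directly):

```python
NULL_BLOCK = -1  # stands in for block_pool.null_block


def remove_skipped_blocks(blocks: list[int], num_computed_tokens: int,
                          sliding_window: int, block_size: int) -> list[int]:
    # Oldest token the window still needs; +1 because the next input token
    # itself is part of the window.
    last_useful_token = num_computed_tokens - sliding_window + 1
    last_useful_block = last_useful_token // block_size
    removed: list[int] = []
    # Walk backwards so the block covering the longest prefix is evicted
    # first, and stop at the first null block: everything before it was
    # already removed by an earlier call.
    for i in range(last_useful_block - 1, -1, -1):
        if blocks[i] == NULL_BLOCK:
            break
        removed.append(blocks[i])
        blocks[i] = NULL_BLOCK
    return removed


# With block_size=2 and sliding_window=4 (the values implied by the updated
# test above), 7 computed tokens put tokens 0-3 outside the window, so
# blocks 0 and 1 are dropped, tail-most first.
blocks = [1000, 1001, 1002, 1003, 1004, 1005]
assert remove_skipped_blocks(blocks, 7, 4, 2) == [1001, 1000]
assert blocks == [NULL_BLOCK, NULL_BLOCK, 1002, 1003, 1004, 1005]
```

Keeping `req_to_blocks` inside the per-type manager is what lets this method look up and mutate the request's block list by `request_id` alone, so `KVCacheManager` no longer needs to thread the list through the call.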