diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 2e16d7d2502e7..a6c0162d3f308 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -629,33 +629,3 @@ def test_reset_prefix_cache():
     assert manager.reset_prefix_cache()
     assert not manager.cached_block_hash_to_block
     assert all([blk.block_hash is None for blk in manager.block_pool])
-
-
-def test_uncache_blocks():
-    manager = KVCacheManager(
-        block_size=16,
-        num_gpu_blocks=10,
-        max_model_len=8192,
-        sliding_window=None,
-        enable_caching=True,
-        num_preallocate_tokens=0,
-    )
-
-    req0 = make_request("0", list(range(30)))
-    blocks = manager.allocate_slots(req0, 30)
-    assert [b.block_id for b in blocks] == [0, 1]
-    assert len(manager.cached_block_hash_to_block) == 1
-
-    req0.num_computed_tokens = 30
-
-    # Simulate speculative tokens.
-    for _ in range(5):
-        req0.append_output_token_ids(8)
-    manager.allocate_slots(req0, 5)
-    assert len(manager.cached_block_hash_to_block) == 2
-
-    # After sampling, assuming only 1 token is accepted.
-    req0.num_computed_tokens = 31
-    num_uncached_blocks = manager.uncache_blocks(req0)
-    assert num_uncached_blocks == 1
-    assert len(manager.cached_block_hash_to_block) == 1
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 94086e4a1f75b..de349ec120999 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -252,29 +252,6 @@ class KVCacheManager:
             if block.ref_cnt == 0:
                 self.free_block_queue.append(block)
 
-    def uncache_blocks(self, request: Request) -> int:
-        """Uncache the blocks that are no longer full based on the
-        num_computed_tokens in the given request. This happens when
-        the blocks were full and cached due to speculative tokens, but the
-        speculative tokens are not accepted.
-
-        Args:
-            request: The request.
-
-        Returns:
-            The number of uncached blocks.
-        """
-        blocks = self.req_to_blocks[request.request_id]
-        num_computed_tokens = request.num_computed_tokens
-        num_full_blocks = num_computed_tokens // self.block_size
-        num_uncached_blocks = 0
-        for block in blocks[num_full_blocks:]:
-            # If the block is not cached, the following blocks are not cached.
-            if not self._maybe_evict_cached_block(block):
-                break
-            num_uncached_blocks += 1
-        return num_uncached_blocks
-
     def reset_prefix_cache(self) -> bool:
         """Reset prefix cache. This function may be used in RLHF
         flows to invalid prefix caching after the weights are updated,
@@ -470,8 +447,22 @@ class KVCacheManager:
             assert prev_block.block_hash is not None
             prev_block_hash_value = prev_block.block_hash.hash_value
 
-        for i, blk in enumerate(full_blocks):
-            blk_idx = blk_start_idx + i
+        # Find the first uncached block. This case should only happen when
+        # speculative decoding is used.
+        offset = 0
+        for blk in full_blocks:
+            if blk.block_hash is None:
+                break
+            else:
+                prev_block_hash_value = blk.block_hash.hash_value
+                offset += 1
+        else:
+            # All blocks are cached.
+            return
+
+        for i, blk in enumerate(full_blocks[offset:]):
+            blk_idx = blk_start_idx + offset + i
+            assert blk.block_hash is None
 
             if blk_idx < num_cached_block_hashes:
                 # The block hash may already be computed in