diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e9c6f1f95cd71..bff3724d95e68 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -112,9 +112,9 @@ def test_kv_cache_block(): assert block.block_hash is None # Test reference count manipulation - block.incr_ref() + block.ref_cnt += 1 assert block.ref_cnt == 1 - block.decr_ref() + block.ref_cnt -= 1 assert block.ref_cnt == 0 # Test block hash setting and resetting diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 5bf4d3a2acb45..ad9854dd29c38 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -276,7 +276,7 @@ class BlockPool: # candidate), so remove it. if block.ref_cnt == 0 and not block.is_null: self.free_block_queue.remove(block) - block.incr_ref() + block.ref_cnt += 1 def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 258805843e227..f3a16d64e19fd 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC): def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> list[int]: """ - Get the number of common prefix blocks for a request. + Get the number of common prefix blocks for all requests in the RUNNING + state for each kv cache group. Args: request_id: The request ID. - num_running_requests: The number of requests in the RUNNING state. + num_running_requests: The total number of requests in the RUNNING + state. Returns: - list[int]: The number of common prefix blocks. + list[int]: The number of common prefix blocks for all requests in + the RUNNING state for each kv cache group. """ num_blocks_per_group = [ manager.get_num_common_prefix_blocks(request_id, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e820a0ad6d5d0..ce333dbe61a19 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -170,10 +170,6 @@ class KVCacheManager: self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes - if self.log_stats: - assert self.prefix_cache_stats is not None - self.prefix_cache_stats.requests += 1 - # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # This can trigger recomputation of an entire block, rather than just @@ -187,6 +183,7 @@ class KVCacheManager: if self.log_stats: assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.hits += num_new_computed_tokens diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3a72ac271afa6..25520eb655111 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -154,14 +154,6 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False - # TODO(Jialin): For performance, let callers handle ref_cnt bumps to - # avoid function calls. - def incr_ref(self): - self.ref_cnt += 1 - - def decr_ref(self): - self.ref_cnt -= 1 - @property def block_hash(self) -> Optional[BlockHashWithGroupId]: return self._block_hash diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 714f49494c9a1..8f310023a8cd3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools from abc import ABC, abstractmethod from collections import defaultdict from typing import Callable @@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC): def get_num_common_prefix_blocks(self, request_id: str, num_running_requests: int) -> int: """ - Get the number of common prefix blocks for a request. + Get the number of common prefix blocks for all requests in the RUNNING + state. Args: request_id: The request ID. - num_running_requests: The number of requests in the RUNNING state. + num_running_requests: The total number of requests in the RUNNING + state. Returns: - The number of common prefix blocks. + The number of common prefix blocks for all requests in the RUNNING + state. """ raise NotImplementedError @@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) max_num_blocks = max_length // kv_cache_spec.block_size - for i, block_hash in zip(range(max_num_blocks), block_hashes): + for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure.