From 4b68c4a55b0fa5846d180532ae7e58db85101e07 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Thu, 30 Oct 2025 12:47:30 -0700 Subject: [PATCH] [Core][Perf] Only invoke save_new_computed_blocks when computed blocks are not empty (#27799) Signed-off-by: Jialin Ouyang --- vllm/v1/core/kv_cache_manager.py | 11 ++++++----- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bb8cec91f36dd..63a1ff06e4049 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -306,11 +306,12 @@ class KVCacheManager: "Computed blocks should be empty when prefix caching is disabled" ) - # Append the new computed blocks to the request blocks until now to - # avoid the case where the new blocks cannot be allocated. - self.coordinator.save_new_computed_blocks( - request.request_id, new_computed_block_list - ) + if new_computed_block_list is not self.empty_kv_cache_blocks.blocks: + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot, num_encoder_tokens diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 575ae3d7d83b6..8f14fb1894707 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -151,7 +151,7 @@ class SingleTypeKVCacheManager(ABC): num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ - num_cached_blocks = self.num_cached_block[request.request_id] + num_cached_blocks = self.num_cached_block.get(request.request_id, 0) num_full_blocks = num_tokens // self.block_size if num_cached_blocks >= num_full_blocks: