From 7c12a765aa2f2a97ebf6b3bc8361b464461832fc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 9 Jul 2025 14:48:35 -0700 Subject: [PATCH] [Misc] Simplify the prefix caching logic on draft tokens (#20701) Signed-off-by: Woosuk Kwon --- vllm/v1/core/kv_cache_manager.py | 16 ++++++++++------ vllm/v1/core/sched/scheduler.py | 5 ----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6937455e7d85..3d5f85d2eacd 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -190,7 +190,6 @@ class KVCacheManager: num_new_tokens: int, num_new_computed_tokens: int = 0, new_computed_blocks: Optional[KVCacheBlocks] = None, - num_draft_tokens: int = 0, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: @@ -286,12 +285,17 @@ class KVCacheManager: if not self.enable_caching or delay_cache_blocks: return KVCacheBlocks(new_blocks) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. + # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + + # num_new_tokens, but must exclude "non-committable" tokens (e.g., + # draft tokens that could be rejected). Therefore, we cap the number + # at `request.num_tokens`, ensuring only "finalized" tokens are cached. + num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, + request.num_tokens) self.coordinator.cache_blocks( - request, self.req_to_block_hashes[request.request_id], - num_computed_tokens + num_new_tokens - num_draft_tokens) + request, + self.req_to_block_hashes[request.request_id], + num_tokens_to_cache, + ) return KVCacheBlocks(new_blocks) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0c3acea3ae40..b2d90614c294 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -241,15 +241,10 @@ class Scheduler(SchedulerInterface): req_index += 1 continue - num_draft_tokens = max( - num_new_tokens + request.num_computed_tokens - - request.num_tokens, 0) - while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_draft_tokens=num_draft_tokens, num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled.