From 35afe1b30b154114dc2ee8329e12f8cf3fe9f576 Mon Sep 17 00:00:00 2001 From: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Date: Fri, 8 Aug 2025 20:04:15 -0400 Subject: [PATCH] [BugFix] [P/D] Handle lookahead token count edge-case with Eagle Spec Decoding and P/D (#22317) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Pradyun Ramadorai Signed-off-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Co-authored-by: Pradyun Ramadorai Co-authored-by: Nicolò Lucchesi --- vllm/v1/core/sched/scheduler.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d39aea1f2d11..430085d9c978 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -437,14 +437,24 @@ class Scheduler(SchedulerInterface): # The request cannot be scheduled. break + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, num_new_local_computed_tokens, new_computed_blocks, - num_lookahead_tokens=self.num_lookahead_tokens, + num_lookahead_tokens=effective_lookahead_tokens, delay_cache_blocks=load_kv_async, ) + if new_blocks is None: # The request cannot be scheduled. break