From 44822d7ff22cb62856b4a107f97f00fc8e1199a0 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Dec 2025 17:15:52 -0800
Subject: [PATCH] [BugFix] Preserve spec decoding uniform decode when
 scheduling (#29759)

Signed-off-by: Nick Hill
---
 tests/v1/e2e/test_spec_decode.py      |  4 +--
 vllm/v1/core/sched/async_scheduler.py |  2 +-
 vllm/v1/core/sched/scheduler.py       | 36 ++++++++++++++++-----------
 3 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 3a25f7411eecd..f711715dec0e6 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
     # Expect the acceptance rate to improve.
     assert first_accept_rate < last_accept_rate
 
-    # Heuristic: expect at least 85% acceptance rate at the end.
-    assert last_accept_rate > 0.85
+    # Heuristic: expect at least 82.5% acceptance rate at the end.
+    assert last_accept_rate > 0.825
 
     del spec_llm
     torch.cuda.empty_cache()
diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py
index 3214f65a09728..7916fafdae1fb 100644
--- a/vllm/v1/core/sched/async_scheduler.py
+++ b/vllm/v1/core/sched/async_scheduler.py
@@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler):
             # in this scheduling step.
             request.num_output_placeholders += 1 + cur_num_spec_tokens
             # Add placeholders for the new tokens in spec_token_ids.
-            # Wwe will update the actual spec token ids in the worker process.
+            # We will update the actual spec token ids in the worker process.
             request.spec_token_ids = [-1] * self.num_spec_tokens
 
         scheduler_output.pending_structured_output_tokens = (
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 4314ba75eceef..c1ead200ba8d6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -236,6 +236,22 @@ class Scheduler(SchedulerInterface):
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
 
+            if (
+                request.num_output_placeholders > 0
+                # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
+                # Since output placeholders are also included in the computed tokens
+                # count, we subtract (num_output_placeholders - 1) to remove any draft
+                # tokens, so that we can be sure no further steps are needed even if
+                # they are all rejected.
+                and request.num_computed_tokens + 2 - request.num_output_placeholders
+                >= request.num_prompt_tokens + request.max_tokens
+            ):
+                # Async scheduling: Avoid scheduling an extra step when we are sure that
+                # the previous step has reached request.max_tokens. We don't schedule
+                # partial draft tokens since this prevents uniform decode optimizations.
+                req_index += 1
+                continue
+
             num_new_tokens = (
                 request.num_tokens_with_spec
                 + request.num_output_placeholders
@@ -245,18 +261,10 @@
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
 
-            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
-            max_total_tokens = min(
-                # Avoid scheduling tokens that we're sure won't will be needed based on
-                # request.max_tokens. For this calculation we assume placeholder
-                # speculated output tokens are rejected.
-                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
-                # Make sure the input position does not exceed the max model len.
-                # This is necessary when using spec decoding.
-                self.max_model_len,
-            )
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
             num_new_tokens = min(
-                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
+                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
             )
 
             # Schedule encoder inputs.
@@ -799,15 +807,15 @@ class Scheduler(SchedulerInterface):
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
-            num_tokens = num_scheduled_tokens[req_id] - len(
-                spec_decode_tokens.get(req_id, ())
-            )
             if self.use_pp:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker. Otherwise, we don't
                 # need to send the sampled tokens back because the model runner
                 # will cache them.
+                num_tokens = num_scheduled_tokens[req_id] - len(
+                    spec_decode_tokens.get(req_id, ())
+                )
                 token_ids = req.all_token_ids[
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
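
For reference, a minimal sketch (not part of the patch itself) of the early-exit arithmetic added to the scheduling loop above, written as a hypothetical standalone helper with made-up token counts:

def should_skip_request(
    num_output_placeholders: int,
    num_computed_tokens: int,
    num_prompt_tokens: int,
    max_tokens: int,
) -> bool:
    """Mirror of the new check: skip scheduling when the previous (async)
    step is guaranteed to have reached max_tokens, even if every draft
    placeholder token ends up rejected."""
    if num_output_placeholders <= 0:
        return False
    # Placeholders are counted in num_computed_tokens, so drop the
    # (num_output_placeholders - 1) draft tokens and add 1 for the token
    # guaranteed to be sampled: num_computed_tokens + 2 - num_output_placeholders.
    guaranteed_len = num_computed_tokens + 2 - num_output_placeholders
    return guaranteed_len >= num_prompt_tokens + max_tokens

# Hypothetical numbers: 100-token prompt, max_tokens=8, 3 outstanding
# placeholders (1 sampled + 2 drafts), 109 tokens counted as computed.
# 109 + 2 - 3 = 108 >= 100 + 8, so the request is skipped rather than being
# scheduled for a partial draft step that would break uniform decode batches.
assert should_skip_request(3, 109, 100, 8)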