diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py
index 6d870b5640dfb..e0645ed43015e 100644
--- a/tests/v1/core/test_async_scheduler.py
+++ b/tests/v1/core/test_async_scheduler.py
@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
     requests = create_requests(num_requests=2, max_tokens=max_tokens)
     req0, req1 = requests
 
+    expected_total_num_scheduled_tokens = 0
     sched_outputs: deque[SchedulerOutput] = deque()
     scheduler.add_request(req0)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1
 
     scheduler.add_request(req1)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1
 
+    total_num_scheduled_tokens = 0
     while sched_outputs:
         sched_output = sched_outputs.popleft()
+        total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)
 
@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert scheduler.get_num_unfinished_requests() == 0
     assert req0.num_output_tokens == max_tokens
     assert req1.num_output_tokens == max_tokens
+    # Ensure we aren't scheduling more tokens than necessary.
+    assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens
 
 
 def test_abort():
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 9b55d2b14b991..ffd9f3e0370f7 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -155,7 +155,6 @@ def test_suffix_decoding_acceptance(
     )
 
     # Run several times and check that the accepted tokens increase.
-    spec_llm.chat(test_prompts, sampling_config)
     num_draft = []
     num_accept = []
     for i in range(10):  # Run multiple times to warm up the cache.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index aeb9869c52813..97341c762b99d 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -217,10 +217,14 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
 
-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
+            # Make sure the input position does not exceed the max model len or
+            # the request's max_tokens.
+            # This is necessary when using spec decoding and/or async scheduling.
+            max_total_tokens = min(
+                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+            )
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
             )
 
             # Schedule encoder inputs.
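Below is an illustrative sketch (not part of the patch) of the capping arithmetic introduced in `Scheduler.schedule()` above: once new tokens are bounded by `min(num_prompt_tokens + max_tokens, max_model_len) - 1 - num_computed_tokens`, a request can never be scheduled more than `num_prompt_tokens + max_tokens - 1` tokens over its lifetime, since the final sampled token is never fed back as input. The helper name `cap_new_tokens` and the example values (a 10-token prompt with `max_tokens=3`) are hypothetical and exist only for this demonstration.

```python
# Hypothetical standalone sketch of the new cap in Scheduler.schedule().
# cap_new_tokens and the example numbers below are made up for illustration.

def cap_new_tokens(
    num_new_tokens: int,
    num_computed_tokens: int,
    num_prompt_tokens: int,
    max_tokens: int,
    max_model_len: int,
) -> int:
    # Same arithmetic as the patch: bound the request by its own budget
    # (prompt + max_tokens) as well as the model's context window.
    max_total_tokens = min(num_prompt_tokens + max_tokens, max_model_len)
    return min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)


if __name__ == "__main__":
    prompt, max_tokens, model_len = 10, 3, 4096
    computed = scheduled_total = 0
    while True:
        # Prefill wants the whole prompt, then one decode token per step.
        want = prompt if computed == 0 else 1
        step = cap_new_tokens(want, computed, prompt, max_tokens, model_len)
        if step <= 0:
            break
        scheduled_total += step
        computed += step
    # 10 + 3 - 1 == 12 tokens scheduled in total, matching the
    # expected_total_num_scheduled_tokens accounting added to the test.
    assert scheduled_total == prompt + max_tokens - 1 == 12
```

This mirrors why the updated `test_stop_by_max_tokens` expects exactly `num_prompt_tokens + max_tokens - 1` scheduled tokens per request: without the cap, async scheduling or speculative decoding could schedule tokens past the request's `max_tokens` that would then be discarded.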