From b73fdb927a23d5aedc571202c34ec7287ec7efce Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Sat, 3 May 2025 10:50:34 -0700
Subject: [PATCH] draft

Signed-off-by: LiuXiaoxuanPKU
---
 tests/v1/core/test_scheduler.py | 55 +++++++++++++++++++++++++++++++++
 vllm/config.py                  |  3 ++
 vllm/v1/core/sched/scheduler.py | 16 ++++++++++
 3 files changed, 74 insertions(+)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index ee4e95856f233..5e4881c1a76fc 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -24,6 +24,7 @@ def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
+    max_num_spec_tokens: Optional[int] = None,
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
@@ -51,6 +52,7 @@ def create_scheduler(
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
+        max_num_spec_tokens=max_num_spec_tokens,
         max_model_len=max_model_len,
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
@@ -684,6 +686,59 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     scheduler.update_from_output(scheduler_output1, model_runner_output)
 
 
+def test_spec_token_budget():
+    """Test scheduling behavior when the spec token budget limits the
+    total number of speculative tokens scheduled across the batch."""
+    # Create a scheduler with a batch-wide spec token budget of 14.
+    scheduler = create_scheduler(
+        max_num_batched_tokens=100,
+        max_num_spec_tokens=14,  # Total spec budget for this test
+    )
+
+    requests = create_requests(
+        num_requests=2,
+        num_tokens=10,
+    )
+
+    spec_tokens = [list(range(10)), list(range(5))]
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+    output = scheduler.schedule()
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    output = scheduler.schedule()
+    request1, request2 = requests
+    # --- Verify request1 ---
+    # num_new_tokens = 1 sampled token + 10 spec tokens = 11
+    # num_scheduled_spec_tokens = 11 - 1 = 10
+    # Budget starts at 14: 10 <= 14 → no truncation
+    # num_new_tokens = min(11, 10 + 1) → 11
+    assert len(request1.spec_token_ids) == 10  # Not truncated
+    assert output.num_scheduled_tokens[request1.request_id] == 11
+    assert len(output.scheduled_spec_decode_tokens[request1.request_id]) == 10
+
+    # --- Verify request2 ---
+    # Remaining budget after request1: 14 - 10 = 4
+    # num_new_tokens = 1 sampled token + 5 spec tokens = 6
+    # num_scheduled_spec_tokens = 6 - 1 = 5 > 4 → truncate to 4
+    # num_new_tokens = min(6, 4 + 1) → 5
+    assert len(request2.spec_token_ids) == 4  # Truncated from 5
+    assert output.num_scheduled_tokens[request2.request_id] == 5
+    assert len(output.scheduled_spec_decode_tokens[request2.request_id]) == 4
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
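
The arithmetic the new test asserts can be checked in isolation. Below is a minimal standalone sketch of the clamp for the pure decode case the test exercises (one sampled token plus the request's proposed spec tokens); `schedule_spec` is a hypothetical helper written for illustration, not part of vLLM's API.

```python
# Hypothetical helper mirroring the budget arithmetic asserted in
# test_spec_token_budget above; illustration only, not vLLM API.
def schedule_spec(num_spec_tokens: int, spec_token_budget: int):
    # Decode step: 1 sampled token plus the request's proposed spec tokens.
    num_new_tokens = 1 + num_spec_tokens
    num_scheduled_spec_tokens = num_new_tokens - 1
    if num_scheduled_spec_tokens > spec_token_budget:
        num_scheduled_spec_tokens = spec_token_budget  # truncate to budget
    # +1 re-admits the sampled (non-speculative) token.
    num_new_tokens = min(num_new_tokens, num_scheduled_spec_tokens + 1)
    spec_token_budget -= num_scheduled_spec_tokens
    return num_new_tokens, num_scheduled_spec_tokens, spec_token_budget


# Request 1 proposes 10 spec tokens against a budget of 14: no truncation.
assert schedule_spec(10, 14) == (11, 10, 4)
# Request 2 proposes 5 spec tokens against the remaining 4: truncated.
assert schedule_spec(5, 4) == (5, 4, 0)
```

Decrementing the budget after the clamp means later requests in the same scheduling step see only what is left, which is what the second half of the test verifies.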
diff --git a/vllm/config.py b/vllm/config.py
index e645103557c12..abcea872c24e5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1841,6 +1841,9 @@ class SchedulerConfig:
     is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
 
+    max_num_spec_tokens: Optional[int] = None
+    """Batch-wide budget of speculative tokens; None means no limit."""
+
     max_num_partial_prefills: int = 1
     """For chunked prefill, the maximum number of sequences that can be
     partially prefilled concurrently."""
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 21711c9292f9f..bd13ffe6487f9 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -62,6 +62,7 @@ class Scheduler(SchedulerInterface):
         self.max_num_scheduled_tokens = \
             self.scheduler_config.max_num_batched_tokens
         self.max_model_len = self.scheduler_config.max_model_len
+        self.max_num_spec_tokens = self.scheduler_config.max_num_spec_tokens
 
         # Create KVConnector for the Scheduler. Note that each Worker
         # will have a corresponding KVConnector with Role=WORKER.
@@ -162,6 +163,8 @@ class Scheduler(SchedulerInterface):
         req_to_new_block_ids: dict[str, list[int]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+        spec_token_budget = self.max_num_spec_tokens
+
         # Encoder-related.
         scheduled_encoder_inputs: dict[str, list[int]] = {}
         encoder_budget = self.max_num_encoder_input_tokens
@@ -184,6 +187,19 @@ class Scheduler(SchedulerInterface):
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
 
+            num_scheduled_spec_tokens = (num_new_tokens +
+                                         request.num_computed_tokens -
+                                         request.num_tokens)
+            if spec_token_budget is not None and num_scheduled_spec_tokens > 0:
+                if num_scheduled_spec_tokens > spec_token_budget:
+                    # We don't truncate the spec_token_ids list here because
+                    # it will be trimmed at the end of the while loop.
+                    num_scheduled_spec_tokens = spec_token_budget
+                # +1 here to include the last generated token.
+                num_new_tokens = min(num_new_tokens,
+                                     num_scheduled_spec_tokens + 1)
+                spec_token_budget -= num_scheduled_spec_tokens
+
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
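
The scheduler change can also be read as a pure function of the per-request counts. The restatement below is a sketch, assuming only the quantities the diff reads from the request; `clamp_to_spec_budget` is hypothetical and simply mirrors the guard above, including the unlimited (None) and exhausted-budget cases. Note the `is not None` check: a plain truthiness test would stop clamping once the budget reaches 0, letting later requests schedule spec tokens past the limit.

```python
from typing import Optional


# Hypothetical restatement of the clamp added in scheduler.py above;
# a sketch for review purposes, not the vLLM implementation itself.
def clamp_to_spec_budget(
    num_new_tokens: int,
    num_computed_tokens: int,
    num_tokens: int,
    spec_token_budget: Optional[int],
) -> tuple[int, Optional[int]]:
    # Spec tokens are whatever extends past the request's known tokens.
    num_scheduled_spec_tokens = (num_new_tokens + num_computed_tokens -
                                 num_tokens)
    if spec_token_budget is not None and num_scheduled_spec_tokens > 0:
        num_scheduled_spec_tokens = min(num_scheduled_spec_tokens,
                                        spec_token_budget)
        # Keep the one sampled (non-speculative) token schedulable.
        num_new_tokens = min(num_new_tokens, num_scheduled_spec_tokens + 1)
        spec_token_budget -= num_scheduled_spec_tokens
    return num_new_tokens, spec_token_budget


# Budget already exhausted: a decode step degrades to the sampled token only.
assert clamp_to_spec_budget(6, 10, 10, 0) == (1, 0)
# No budget configured: the token counts pass through unchanged.
assert clamp_to_spec_budget(6, 10, 10, None) == (6, None)
```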