diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 07d7c12a4f5e..70e869178804 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1293,7 +1293,8 @@ def create_requests_with_priority(
         mm_positions: Optional[list[list[PlaceholderRange]]] = None,
         max_tokens: int = 16,
         stop_token_ids: Optional[list[int]] = None,
-        prompt_logprobs: Optional[int] = None):
+        prompt_logprobs: Optional[int] = None,
+        starting_idx: int = 0):
     """Create requests with specified priorities and arrival times."""
     assert len(priorities) == num_requests
     if arrival_times is not None:
@@ -1315,8 +1316,8 @@ def create_requests_with_priority(
             mm_position = None
             mm_kwargs = None
         request = Request(
-            request_id=f"{i}",
-            prompt_token_ids=[i] * num_tokens,
+            request_id=f"{i + starting_idx}",
+            prompt_token_ids=[i + starting_idx] * num_tokens,
             sampling_params=sampling_params,
             pooling_params=None,
             multi_modal_kwargs=mm_kwargs,
@@ -1813,3 +1814,87 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(output.scheduled_new_reqs) == 0
     assert len(scheduler.running) == 0
     assert len(scheduler.waiting) == 1
+
+
+def test_priority_scheduling_preemption_when_out_of_kv():
+    """Test that priority scheduling preempts lower-priority requests
+    when out of KV cache space."""
+    # Create a scheduler with very limited memory to force preemption
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=2,  # Allow multiple requests
+        max_num_batched_tokens=200,
+        num_blocks=5,  # Can hold 64 tokens (the first block is the null block)
+        block_size=16,  # Standard block size
+    )
+
+    # Create a low-priority request and schedule it
+    request_low = create_requests_with_priority(
+        num_requests=1,
+        priorities=[1],
+        arrival_times=[0.0],
+        num_tokens=30,
+        starting_idx=0,
+    )[0]
+    scheduler.add_request(request_low)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 1
+
+    # Simulate model execution
+    model_output = ModelRunnerOutput(
+        req_ids=[request_low.request_id],
+        req_id_to_index={request_low.request_id: 0},
+        sampled_token_ids=[[100]],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # Create a high-priority request and schedule it
+    request_high = create_requests_with_priority(
+        num_requests=1,
+        priorities=[0],
+        arrival_times=[1.0],
+        num_tokens=32,
+        starting_idx=1,
+    )[0]
+    scheduler.add_request(request_high)
+    output = scheduler.schedule()
+    # The KV cache should be full at this point
+    assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
+    assert len(output.scheduled_new_reqs) == 1
+    assert output.scheduled_cached_reqs.num_reqs == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 2
+
+    # Simulate model execution
+    requests = [request_low, request_high]
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[100] for _ in requests],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # Schedule again - this should trigger preemption:
+    # request_low now needs 32 tokens = 2 blocks and
+    # request_high needs 33 tokens = 3 blocks,
+    # which together do not fit in the 4 usable blocks.
+    output = scheduler.schedule()
+
+    # request_low should have been preempted
+    assert len(output.scheduled_new_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 1
+    assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
+    assert len(scheduler.waiting) == 1
+    assert len(scheduler.running) == 1
\ No newline at end of file
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 14a914d8f2f0..3bd2fe2f0515 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -253,6 +253,8 @@ class Scheduler(SchedulerInterface):
                         key=lambda r: (r.priority, r.arrival_time),
                     )
                     self.running.remove(preempted_req)
+                    if preempted_req in scheduled_running_reqs:
+                        scheduled_running_reqs.remove(preempted_req)
                 else:
                     preempted_req = self.running.pop()
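
The block arithmetic in the test's final comment can be checked in isolation. Below is a minimal sketch, assuming the test's configuration (block_size=16, num_blocks=5 with one block reserved as the null block); blocks_needed is a hypothetical local helper for illustration, not a vLLM API.

from math import ceil


def blocks_needed(num_tokens: int, block_size: int = 16) -> int:
    # Each KV cache block holds block_size tokens, so round up.
    return ceil(num_tokens / block_size)


usable_blocks = 5 - 1  # num_blocks=5 minus the reserved null block

# After two decode steps: request_low has 30 prompt + 2 sampled = 32 tokens;
# request_high has 32 prompt + 1 sampled = 33 tokens.
assert blocks_needed(32) == 2
assert blocks_needed(33) == 3

# 2 + 3 = 5 blocks needed but only 4 are usable, so the scheduler must
# preempt one request; under the priority policy the victim is the one
# with the highest (priority, arrival_time) key, i.e. request_low.
assert blocks_needed(32) + blocks_needed(33) > usable_blocks

This is why the scheduler.py hunk is needed: under the priority policy, the victim chosen by max(self.running, ...) can be a request that was already placed in scheduled_running_reqs earlier in the same schedule() call. Removing it there keeps the preempted request out of that step's scheduled_cached_reqs, which is exactly what the test's final assertions verify.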