diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 8916aa580000a..c12f2fd594385 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -20,9 +20,10 @@ def create_scheduler(
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
 ) -> Scheduler:
     '''Create scheduler under test.
-
+
     Args:
       model: model under test
       max_num_seqs: max sequences to schedule
@@ -38,6 +39,7 @@ def create_scheduler(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
+        long_prefill_token_threshold=long_prefill_token_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -263,6 +265,78 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
+    """Test scheduling behavior with concurrent partial requests.
+
+    This test verifies that when there are multiple long prefill requests
+    in the RUNNING state, they can be scheduled together.
+
+    """
+    scheduler = create_scheduler(
+        model="facebook/opt-125m",
+        max_num_batched_tokens=1024,
+        long_prefill_token_threshold=400,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+    requests = create_requests(
+        num_requests=3,
+        num_tokens=800,
+    )
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 3
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+
+    # The first request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+    # The second request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[1].request_id] == 400
+    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
+    assert output.num_scheduled_tokens[requests[2].request_id] == 224
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Schedule the next step. All three requests are running.
+    # Process the remaining prefills of the first and second requests.
+    output1 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output1.scheduled_new_reqs) == 0
+    assert len(output1.scheduled_cached_reqs) == 3
+    assert len(output1.finished_req_ids) == 0
+    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[2].request_id] == 224
+
+    # Schedule the third step. All three requests are running.
+    # The first and second requests are in the decode stage.
+    # All the remaining tokens in the third request are processed.
+    scheduler.update_from_output(output1, model_runner_output)
+    output2 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output2.scheduled_new_reqs) == 0
+    assert len(output2.scheduled_cached_reqs) == 3
+    assert len(output2.finished_req_ids) == 0
+    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
+    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
+    assert output2.num_scheduled_tokens[
+        requests[2].request_id] == 800 - 224 - 224
+
+
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
     scheduler = create_scheduler()
diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py
new file mode 100644
index 0000000000000..0a79424a30b74
--- /dev/null
+++ b/tests/v1/core/test_scheduler_e2e.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import pytest
+
+from vllm import LLM
+
+if os.getenv("VLLM_USE_V1", "0") != "1":
+    pytest.skip("Test package requires V1", allow_module_level=True)
+
+MODEL = "meta-llama/Llama-3.2-1B"
+PROMPT = "Hello my name is Robert and I"
+
+
+@pytest.fixture(scope="module")
+def model() -> LLM:
+    return LLM(MODEL,
+               enforce_eager=True,
+               enable_prefix_caching=True,
+               long_prefill_token_threshold=2,
+               max_num_batched_tokens=6,
+               max_num_seqs=3)
+
+
+def test_concurrent_partial_prefill(model):
+    outputs = model.generate([PROMPT] * 3)
+    assert len(outputs) == 3
+    for output in outputs:
+        assert len(output.outputs) == 1
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 867842cc31d10..65a1676c0637d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1625,9 +1625,7 @@ class EngineArgs:
         if (self.max_num_partial_prefills
                 != EngineArgs.max_num_partial_prefills
                 or self.max_long_partial_prefills
-                != EngineArgs.max_long_partial_prefills
-                or self.long_prefill_token_threshold
-                != EngineArgs.long_prefill_token_threshold):
+                != EngineArgs.max_long_partial_prefills):
             _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                                recommend_to_remove=False)
             return False
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index c71eb9a0445c7..9b0cddb2818c2 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -152,6 +152,10 @@ class Scheduler(SchedulerInterface):
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
@@ -299,6 +303,10 @@ class Scheduler(SchedulerInterface):
                 num_computed_tokens -= self.block_size
                 num_new_tokens = self.block_size
                 computed_blocks.pop()
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
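
For reference, the capping logic that both scheduler.py hunks introduce can be summarized as follows. This is a minimal, illustrative sketch rather than vLLM's actual Scheduler code; the schedule_step helper and its arguments are hypothetical, but the numbers mirror the assertions in test_schedule_concurrent_partial_requests (three 800-token prompts, a 1024-token budget, and long_prefill_token_threshold=400).

# Illustrative sketch only -- not vLLM's Scheduler API. It shows how a
# per-step token budget plus a long-prefill threshold cap the tokens
# scheduled for each request, so several long prefills can share a batch.
def schedule_step(remaining_tokens: list[int], token_budget: int,
                  long_prefill_token_threshold: int) -> list[int]:
    """Return the number of tokens scheduled per request in one step."""
    scheduled = []
    for remaining in remaining_tokens:
        num_new_tokens = remaining
        if long_prefill_token_threshold > 0:
            # Cap long prefills at the threshold (as in both hunks above).
            num_new_tokens = min(num_new_tokens,
                                 long_prefill_token_threshold)
        # Never exceed what is left of this step's token budget.
        num_new_tokens = min(num_new_tokens, token_budget)
        scheduled.append(num_new_tokens)
        token_budget -= num_new_tokens
    return scheduled


# Three 800-token prompts, budget 1024, threshold 400 -> 400/400/224,
# matching the first scheduling step asserted in the unit test above.
assert schedule_step([800, 800, 800], 1024, 400) == [400, 400, 224]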