[V1] Support long_prefill_token_threshold in v1 scheduler (#15419)
Signed-off-by: Lu Fang <lufang@fb.com>
Parent: 6aa196c8dc
Commit: 082ab86f5f
@@ -20,6 +20,7 @@ def create_scheduler(
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -38,6 +39,7 @@ def create_scheduler(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
+        long_prefill_token_threshold=long_prefill_token_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -263,6 +265,78 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
+    """Test scheduling behavior with concurrent partial requests.
+
+    This test verifies that when there are multiple long prefill requests in
+    the RUNNING state, they can be scheduled together.
+    """
+    scheduler = create_scheduler(
+        model="facebook/opt-125m",
+        max_num_batched_tokens=1024,
+        long_prefill_token_threshold=400,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+    requests = create_requests(
+        num_requests=3,
+        num_tokens=800,
+    )
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 3
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+
+    # The first request is scheduled partially: 400 tokens.
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+    # The second request is scheduled partially: 400 tokens.
+    assert output.num_scheduled_tokens[requests[1].request_id] == 400
+    # The third request is also scheduled partially: 1024 - 400 - 400 = 224.
+    assert output.num_scheduled_tokens[requests[2].request_id] == 224
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Schedule the next step. All three requests are running.
+    # Process the remaining prefill tokens of the first and second requests.
+    output1 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output1.scheduled_new_reqs) == 0
+    assert len(output1.scheduled_cached_reqs) == 3
+    assert len(output1.finished_req_ids) == 0
+    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[2].request_id] == 224
+
+    # Schedule the third step. All three requests are running.
+    # The first and second requests are in the decode stage.
+    # All the remaining prefill tokens of the third request are processed.
+    scheduler.update_from_output(output1, model_runner_output)
+    output2 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output2.scheduled_new_reqs) == 0
+    assert len(output2.scheduled_cached_reqs) == 3
+    assert len(output2.finished_req_ids) == 0
+    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
+    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
+    assert output2.num_scheduled_tokens[
+        requests[2].request_id] == 800 - 224 - 224
+
+
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
     scheduler = create_scheduler()
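The expected token counts asserted above follow from the capping rule this commit adds: each partial prefill is first capped at long_prefill_token_threshold and then at the remaining per-step token budget. Below is a minimal standalone sketch that re-derives the 400/400/224 split for the first two steps; simulate_step and its arguments are illustrative names, not part of vLLM, and decode tokens (the third step in the test) are not modeled.

# Toy re-computation of the test's expected numbers; not vLLM's scheduler.
# simulate_step and its arguments are made-up names for illustration.
def simulate_step(remaining: list[int], budget: int, threshold: int) -> list[int]:
    scheduled = []
    for left in remaining:
        # Cap each request by the long-prefill threshold, then by the budget.
        n = min(left, threshold) if threshold > 0 else left
        n = min(n, budget)
        scheduled.append(n)
        budget -= n
    return scheduled

remaining = [800, 800, 800]                 # three 800-token prompts
step0 = simulate_step(remaining, budget=1024, threshold=400)
assert step0 == [400, 400, 224]             # matches the first schedule() call
remaining = [r - s for r, s in zip(remaining, step0)]
step1 = simulate_step(remaining, budget=1024, threshold=400)
assert step1 == [400, 400, 224]             # matches output1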
tests/v1/core/test_scheduler_e2e.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import pytest
+
+from vllm import LLM
+
+if os.getenv("VLLM_USE_V1", "0") != "1":
+    pytest.skip("Test package requires V1", allow_module_level=True)
+
+MODEL = "meta-llama/Llama-3.2-1B"
+PROMPT = "Hello my name is Robert and I"
+
+
+@pytest.fixture(scope="module")
+def model() -> LLM:
+    return LLM(MODEL,
+               enforce_eager=True,
+               enable_prefix_caching=True,
+               long_prefill_token_threshold=2,
+               max_num_batched_tokens=6,
+               max_num_seqs=3)
+
+
+def test_concurrent_partial_prefill(model):
+    outputs = model.generate([PROMPT] * 3)
+    assert len(outputs) == 3
+    for output in outputs:
+        assert len(output.outputs) == 1
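The fixture above uses deliberately tiny limits (a 2-token threshold, a 6-token batch budget, 3 sequences) so every prompt is forced through several partial-prefill steps. For a rough sense of everyday offline usage, a hedged sketch follows; the parameter values and sampling settings are illustrative and are not taken from this commit, and running it on V1 may require VLLM_USE_V1=1, mirroring the gate in the test above.

# Illustrative usage sketch (assumed values, not from this commit).
from vllm import LLM, SamplingParams

llm = LLM(
    "meta-llama/Llama-3.2-1B",
    enforce_eager=True,
    enable_prefix_caching=True,
    long_prefill_token_threshold=512,   # cap each prefill chunk at 512 tokens
    max_num_batched_tokens=2048,        # per-step token budget
    max_num_seqs=8,
)
outputs = llm.generate(["Hello my name is Robert and I"] * 4,
                       SamplingParams(max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)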
@@ -1625,9 +1625,7 @@ class EngineArgs:
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
-               != EngineArgs.max_long_partial_prefills
-               or self.long_prefill_token_threshold
-               != EngineArgs.long_prefill_token_threshold):
+               != EngineArgs.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False
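With the condition above relaxed, a non-default long_prefill_token_threshold on its own no longer triggers the "Concurrent Partial Prefill" V1 fallback; only max_num_partial_prefills and max_long_partial_prefills still do. A small hedged sketch of what that means at the EngineArgs level (the import path is assumed from the class name in the hunk header):

# Sketch (assumption drawn from the hunk above): with the V1 engine, a
# non-default long_prefill_token_threshold by itself no longer forces the
# "Concurrent Partial Prefill" fallback.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",
    long_prefill_token_threshold=400,   # now honored by the V1 scheduler
)
# max_num_partial_prefills / max_long_partial_prefills keep their defaults,
# so the relaxed condition above evaluates to False and no fallback occurs.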
@@ -152,6 +152,10 @@ class Scheduler(SchedulerInterface):
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
 
@@ -299,6 +303,10 @@ class Scheduler(SchedulerInterface):
                 num_computed_tokens -= self.block_size
                 num_new_tokens = self.block_size
                 computed_blocks.pop()
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
 
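Both scheduler hunks apply the same rule: when long_prefill_token_threshold is positive, cap a request's newly scheduled tokens at the threshold before applying the per-step token budget, so a single long prefill cannot monopolize a batch. A minimal standalone restatement of that rule follows; cap_new_tokens is an illustrative helper name, not part of the Scheduler API.

# Standalone restatement of the rule added in both hunks; cap_new_tokens is a
# made-up helper, not vLLM code.
def cap_new_tokens(num_new_tokens: int, token_budget: int,
                   long_prefill_token_threshold: int) -> int:
    if long_prefill_token_threshold > 0:
        num_new_tokens = min(num_new_tokens, long_prefill_token_threshold)
    num_new_tokens = min(num_new_tokens, token_budget)
    assert num_new_tokens > 0
    return num_new_tokens

# An 800-token prefill with threshold=400 and budget=1024 is chunked to 400
# tokens this step; the remainder is scheduled on later steps.
assert cap_new_tokens(800, 1024, 400) == 400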