[V1][PP] Optimization: continue scheduling prefill chunks (#17080)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Rui Qiao 2025-04-24 05:27:08 -07:00 committed by GitHub
parent a9138e85b1
commit c0dfd97519
5 changed files with 128 additions and 74 deletions

View File

@@ -437,7 +437,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -489,7 +488,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -539,7 +537,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -589,7 +586,6 @@ def test_stop_via_update_from_output():
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],

View File

@@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import threading
import time
import uuid
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
import pytest
from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
self, kv_cache_configs: list[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)
# This executor actually can only run 1 batch at a time
self.semaphore = threading.Semaphore(1)
# Create a thread pool with a single worker
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def execute_model(
self,
scheduler_output,
) -> Future[ModelRunnerOutput]:
"""Make execute_model non-blocking."""
future: Future[ModelRunnerOutput] = Future()
def _thread_wrapper(scheduler_output, future):
with self.semaphore:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
# Make a copy because output[0] may be reused
# by the next batch.
output = copy.deepcopy(output[0])
future.set_result(output)
def _execute():
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
threading.Thread(target=_thread_wrapper,
args=(scheduler_output, future)).start()
return future
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property
def max_concurrent_batches(self) -> int:
return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
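A minimal standalone sketch (not part of the diff) of the pattern the updated test relies on: a single-worker ThreadPoolExecutor turns a blocking model call into a Future while still serializing execution, so at most one batch runs at a time even though two can be queued. SerializedAsyncExecutor and run_model are hypothetical names, not vLLM APIs.

import copy
from concurrent.futures import Future, ThreadPoolExecutor

class SerializedAsyncExecutor:
    """Return futures immediately, but execute batches one at a time."""

    def __init__(self, run_model):
        self._run_model = run_model  # hypothetical blocking callable
        self._pool = ThreadPoolExecutor(max_workers=1)  # single worker = serialized

    def execute_model(self, scheduler_output) -> Future:
        def _execute():
            output = self._run_model(scheduler_output)
            # Copy because the runner may reuse its output buffer
            # for the next batch.
            return copy.deepcopy(output)
        return self._pool.submit(_execute)

    def shutdown(self):
        self._pool.shutdown(wait=False)

# Usage: the caller gets a Future right away and can line up the next batch.
executor = SerializedAsyncExecutor(run_model=lambda batch: {"tokens": batch})
future = executor.execute_model({"req0": 10})
print(future.result())
executor.shutdown()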
@@ -299,14 +298,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 2
assert scheduler_output.num_scheduled_tokens[1] == 8
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[0].num_computed_tokens == 12
assert engine_core.scheduler.requests[1].num_computed_tokens == 8
assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Loop through both requests.
while engine_core.scheduler.get_num_unfinished_requests() == 2:
engine_core.step_with_batch_queue()
# Batch queue is full. Finish Batch 1.
engine_core.step_with_batch_queue()
# We reach here once the result of the first request is available.
while engine_core.scheduler.get_num_unfinished_requests() == 1:
engine_core.step_with_batch_queue()
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue()
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 1
# Loop until req0 is finished.
step = 0
req_id = 0
expected_num_tokens = [
engine_core.scheduler.requests[0].num_tokens + 1,
engine_core.scheduler.requests[1].num_tokens + 1,
]
while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()
if step % 2 == 0:
# Even steps consume an output.
assert output is not None
assert len(output.outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
else:
# Odd steps schedule a new batch.
assert output is None
step += 1
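To make the walkthrough above easier to follow, here is a toy reproduction of the token accounting only (not vLLM code; the 10-token budget and 12-token prompts are assumed from the asserts above). The key point is that num_computed_tokens is advanced at schedule time, so the second chunk of req0 can be scheduled while the first chunk is still executing.

BUDGET = 10                              # assumed per-step token budget
prompts = {"req0": 12, "req1": 12}       # assumed prompt lengths
computed = {"req0": 0, "req1": 0}        # bumped when a chunk is *scheduled*

def schedule_step():
    budget, batch = BUDGET, {}
    for rid, total in prompts.items():
        chunk = min(total - computed[rid], budget)
        if chunk > 0:
            batch[rid] = chunk
            computed[rid] += chunk       # advance before execution finishes
            budget -= chunk
        if budget == 0:
            break
    return batch

for step in range(3):
    print(f"step {step}: scheduled {schedule_step()}, computed={computed}")
# step 0: scheduled {'req0': 10}, computed={'req0': 10, 'req1': 0}
# step 1: scheduled {'req0': 2, 'req1': 8}, computed={'req0': 12, 'req1': 8}
# step 2: scheduled {'req1': 4}, computed={'req0': 12, 'req1': 12}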

View File

@@ -117,11 +117,6 @@ class SchedulerInterface(ABC):
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
raise NotImplementedError
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import time
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable
from typing import Optional, Union
@@ -88,9 +88,6 @@ class Scheduler(SchedulerInterface):
# Priority queues for requests.
self.waiting: deque[Request] = deque()
self.running: list[Request] = []
# The requests that have been scheduled and are being executed
# by the executor.
self.scheduled_req_ids: set[str] = set()
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@ class Scheduler(SchedulerInterface):
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> CachedRequestData
self._cached_reqs_data: dict[str, CachedRequestData] = {}
# Request id -> deque of CachedRequestData
self._cached_reqs_data: dict[
str, deque[CachedRequestData]] = defaultdict(deque)
# Encoder-related.
# Calculate encoder cache size if applicable
@@ -171,10 +169,6 @@ class Scheduler(SchedulerInterface):
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if request.request_id in self.scheduled_req_ids:
# This request has already been scheduled.
req_index += 1
continue
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
@@ -183,33 +177,35 @@
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted.
# NOTE(woosuk): By using `continue` instead of `break` here,
# we intentionally relax the strict FCFS scheduling policy
# to allow lower-priority requests to be scheduled when a
# higher-priority request is blocked by encoder constraints.
req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if num_new_tokens == 0:
# The request cannot be scheduled for one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
@@ -243,7 +239,6 @@
# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
@@ -382,7 +377,6 @@
request.request_id] = req_index
req_index += 1
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
@@ -521,18 +515,21 @@
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data_queue = self._cached_reqs_data.get(request.request_id)
if req_data_queue:
req_data = req_data_queue.popleft()
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
# No cached request data, or all cached request data has been
# used by the scheduled requests.
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data
def _try_schedule_encoder_inputs(
@@ -561,6 +558,8 @@
Note that num_computed_tokens includes both locally cached
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
@@ -728,10 +727,13 @@
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(req_id)
if not stopped:
new_running.append(request)
# Return the cached request data to their queues so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
self._cached_reqs_data[req_data.req_id].append(req_data)
self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
@@ -774,7 +776,6 @@
if request.status == RequestStatus.RUNNING:
self.running.remove(request)
self.scheduled_req_ids.discard(request.request_id)
else:
self.waiting.remove(request)
request.status = finished_status
@@ -795,10 +796,6 @@
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
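A hedged sketch of the recycling scheme this file switches to, with a simplified CachedData class standing in for CachedRequestData: once prefill chunks are pipelined, the same request can sit in two in-flight SchedulerOutputs at once, so a single cached object per request id is not enough. Each id therefore keeps a deque that is popped at schedule time and refilled after the batch's output is processed. CachePool, take, and give_back are illustrative names only.

from collections import defaultdict, deque
from dataclasses import dataclass

@dataclass
class CachedData:                         # simplified stand-in
    req_id: str
    new_token_ids: list

class CachePool:
    def __init__(self):
        # Request id -> deque of reusable cached objects.
        self._pool = defaultdict(deque)

    def take(self, req_id, new_token_ids):
        queue = self._pool[req_id]
        if queue:
            data = queue.popleft()        # reuse an object not currently in flight
            data.new_token_ids = new_token_ids
            return data
        # All cached objects are in flight (or none exist yet): allocate.
        return CachedData(req_id, new_token_ids)

    def give_back(self, data):
        # Called once the batch carrying `data` has been processed.
        self._pool[data.req_id].append(data)

pool = CachePool()
a = pool.take("req0", [1])                # chunk 1 scheduled
b = pool.take("req0", [2])                # chunk 2 scheduled before chunk 1 returns
assert a is not b                         # two live objects, hence the deque
pool.give_back(a)
assert pool.take("req0", [3]) is a        # recycled on the next schedule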

View File

@@ -210,10 +210,10 @@ class EngineCore:
Note that if there is nothing to output in this step, None is returned.
The execution flow is as follows:
1. Try to schedule a new batch if there are unscheduled requests
and the job queue is not full. If a new batch is scheduled, directly
return an empty engine core output. In other words, we won't check
and return model outputs before the batch queue is full.
1. Try to schedule a new batch if the batch queue is not full.
If a new batch is scheduled, directly return an empty engine core
output. In other words, filling the batch queue has a higher priority
than getting model outputs.
2. If there is no new scheduled batch, meaning that the batch queue
is full or no other requests can be scheduled, we block until the first
batch in the job queue is finished.
@@ -223,10 +223,10 @@ class EngineCore:
engine_core_outputs = None
scheduler_output = None
# If there are unscheduled requests and the job queue
# is not full, schedule a new batch. Note that this is not blocking.
if (self.scheduler.get_num_unscheduled_requests() > 0
and not self.batch_queue.full()):
# Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking.
if not self.batch_queue.full():
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
@@ -238,6 +238,10 @@
# If no more requests can be scheduled and the job queue is not empty,
# block until the first batch in the job queue is finished.
# TODO(comaniac): Ideally we should peek the first batch in the
# job queue to check if it's finished before scheduling a new batch,
# but peeking the first element in a queue is not thread-safe,
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
# Blocking until the first result is available.
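A simplified restatement of the flow described above, not the vLLM implementation: the scheduler, executor, and batch_queue method names mirror those visible in the diff, everything else is assumed. Filling the batch queue takes priority; only when nothing new can be scheduled does the step block on the oldest in-flight batch.

import queue

def step_with_batch_queue(scheduler, executor, batch_queue: queue.Queue):
    """One engine step: prefer filling the queue, otherwise drain it."""
    scheduled = False
    if not batch_queue.full():
        scheduler_output = scheduler.schedule()      # may schedule nothing
        if scheduler_output.total_num_scheduled_tokens > 0:
            future = executor.execute_model(scheduler_output)
            batch_queue.put_nowait((future, scheduler_output))
            scheduled = True

    if scheduled:
        return None                                  # a new batch went out

    if not batch_queue.empty():
        # Nothing new could be scheduled: block on the oldest batch.
        future, scheduler_output = batch_queue.get_nowait()
        model_output = future.result()
        return scheduler.update_from_output(scheduler_output, model_output)

    return None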