[V1][PP] Optimization: continue scheduling prefill chunks (#17080)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Rui Qiao committed on 2025-04-24 05:27:08 -07:00 (committed by GitHub)
parent a9138e85b1
commit c0dfd97519
5 changed files with 128 additions and 74 deletions

View File

@@ -437,7 +437,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -489,7 +488,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -539,7 +537,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -589,7 +586,6 @@ def test_stop_via_update_from_output():
     requests[0].num_computed_tokens = requests[0].num_tokens
     scheduler.requests[requests[0].request_id] = requests[0]
     scheduler.running.append(requests[0])
-    scheduler.scheduled_req_ids.add(requests[0].request_id)
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],

View File

@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
-import threading
 import time
 import uuid
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 
 import pytest
 from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                 self, kv_cache_configs: list[KVCacheConfig]) -> None:
             super().initialize_from_config(kv_cache_configs)
-            # This executor actually can only run 1 batch at a time
-            self.semaphore = threading.Semaphore(1)
+            # Create a thread pool with a single worker
+            self.thread_pool = ThreadPoolExecutor(max_workers=1)
 
         def execute_model(
             self,
             scheduler_output,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
-            future: Future[ModelRunnerOutput] = Future()
 
-            def _thread_wrapper(scheduler_output, future):
-                with self.semaphore:
-                    output = self.collective_rpc("execute_model",
-                                                 args=(scheduler_output, ))
-                    # Make a copy because output[0] may be reused
-                    # by the next batch.
-                    output = copy.deepcopy(output[0])
-                    future.set_result(output)
+            def _execute():
+                output = self.collective_rpc("execute_model",
+                                             args=(scheduler_output, ))
+                # Make a copy because output[0] may be reused
+                # by the next batch.
+                return copy.deepcopy(output[0])
 
-            threading.Thread(target=_thread_wrapper,
-                             args=(scheduler_output, future)).start()
-            return future
+            # Use the thread pool instead of creating a new thread
+            return self.thread_pool.submit(_execute)
 
         @property
         def max_concurrent_batches(self) -> int:
             return 2
 
+        def shutdown(self):
+            if hasattr(self, 'thread_pool'):
+                self.thread_pool.shutdown(wait=False)
+
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
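
The switch from a manual thread plus semaphore to a single-worker ThreadPoolExecutor is a generic pattern worth noting: submit() hands back a Future immediately while the pool still serializes execution. A minimal standalone sketch of that idea follows; the NonBlockingRunner class and run_batch callable are made up for illustration and are not vLLM APIs.

from concurrent.futures import Future, ThreadPoolExecutor

class NonBlockingRunner:
    """Illustrative only: wraps a blocking run_batch() call in a Future."""

    def __init__(self, run_batch):
        self._run_batch = run_batch
        # One worker => batches still execute one at a time, in order.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def execute(self, batch) -> Future:
        # Returns immediately; fut.result() blocks until the batch is done.
        return self._pool.submit(self._run_batch, batch)

    def shutdown(self):
        self._pool.shutdown(wait=False)

runner = NonBlockingRunner(run_batch=lambda batch: [tok * 2 for tok in batch])
fut = runner.execute([1, 2, 3])   # non-blocking
print(fut.result())               # [2, 4, 6]
runner.shutdown()
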
@@ -299,14 +298,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         # Schedule Batch 1: (10, req0)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 1
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 10
+        # num_computed_tokens should have been updated immediately.
+        assert engine_core.scheduler.requests[
+            req0.request_id].num_computed_tokens == 10
+
+        # Schedule Batch 2: (2, req0), (8, req1)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 2
+        assert scheduler_output.num_scheduled_tokens[1] == 8
+        # num_computed_tokens should have been updated immediately.
+        assert engine_core.scheduler.requests[0].num_computed_tokens == 12
+        assert engine_core.scheduler.requests[1].num_computed_tokens == 8
+
         assert engine_core.scheduler.get_num_unfinished_requests() == 2
 
-        # Loop through both requests.
-        while engine_core.scheduler.get_num_unfinished_requests() == 2:
-            engine_core.step_with_batch_queue()
+        # Batch queue is full. Finish Batch 1.
+        engine_core.step_with_batch_queue()
 
-        # Reaching here when got the result of the first request.
-        while engine_core.scheduler.get_num_unfinished_requests() == 1:
-            engine_core.step_with_batch_queue()
+        # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+        # because it is in the decoding stage now.
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 4
+
+        # Batch queue is full. Finish Batch 2. Get first token of req0.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
+
+        # Schedule Batch 4: (1, req0).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 1
+
+        # Batch queue is full. Finish Batch 3. Get first token of req1.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
+
+        # Schedule Batch 5: (1, req1).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 1
+
+        # Loop until req0 is finished.
+        step = 0
+        req_id = 0
+        expected_num_tokens = [
+            engine_core.scheduler.requests[0].num_tokens + 1,
+            engine_core.scheduler.requests[1].num_tokens + 1,
+        ]
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            output = engine_core.step_with_batch_queue()
+            if step % 2 == 0:
+                # Even steps consumes an output.
+                assert output is not None
+                assert len(output.outputs) == 1
+                if req_id in engine_core.scheduler.requests:
+                    assert engine_core.scheduler.requests[
+                        req_id].num_tokens == expected_num_tokens[req_id]
+                expected_num_tokens[req_id] += 1
+                req_id = (req_id + 1) % 2
+            else:
+                # Odd steps schedules a new batch.
+                assert output is None
+            step += 1
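
As a sanity check on the schedule this test expects, here is a tiny back-of-the-envelope simulation, assuming two requests of 12 prompt tokens each, a 10-token budget per batch, and num_computed_tokens advanced at schedule time. The names are illustrative and nothing here touches vLLM; it only covers the prefill batches (1-3), not the later decode steps.

budget = 10
prompts = {"req0": 12, "req1": 12}
computed = {"req0": 0, "req1": 0}

def schedule_batch():
    """Greedily pack not-yet-computed prompt tokens into one batch."""
    batch, remaining = {}, budget
    for rid, total in prompts.items():
        n = min(total - computed[rid], remaining)
        if n > 0:
            batch[rid] = n
            computed[rid] += n   # advanced at schedule time, not at completion
            remaining -= n
    return batch

print(schedule_batch())  # {'req0': 10}            -> Batch 1
print(schedule_batch())  # {'req0': 2, 'req1': 8}  -> Batch 2
print(schedule_batch())  # {'req1': 4}             -> Batch 3 (req0 is done with prefill)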

View File

@@ -117,11 +117,6 @@ class SchedulerInterface(ABC):
         not yet returned in SchedulerOutputs."""
         return self.has_unfinished_requests() or self.has_finished_requests()
 
-    @abstractmethod
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        raise NotImplementedError
-
     @abstractmethod
     def reset_prefix_cache(self) -> bool:
         """Reset the prefix cache for KV cache.

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import time
-from collections import deque
+from collections import defaultdict, deque
 from collections.abc import Iterable
 from typing import Optional, Union
@@ -88,9 +88,6 @@ class Scheduler(SchedulerInterface):
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
         self.running: list[Request] = []
-        # The requests that have been scheduled and are being executed
-        # by the executor.
-        self.scheduled_req_ids: set[str] = set()
 
         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Request id -> deque of CachedRequestData
+        self._cached_reqs_data: dict[
+            str, deque[CachedRequestData]] = defaultdict(deque)
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -171,10 +169,6 @@
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
-            if request.request_id in self.scheduled_req_ids:
-                # This request has already been scheduled.
-                req_index += 1
-                continue
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
@@ -183,33 +177,35 @@
                 num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens > 0
 
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
                 self.max_model_len - request.num_computed_tokens)
-            assert num_new_tokens > 0
 
             # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
                  new_encoder_budget) = self._try_schedule_encoder_inputs(
                      request, request.num_computed_tokens, num_new_tokens,
                      encoder_budget)
-                if num_new_tokens == 0:
-                    # The request cannot be scheduled because the encoder budget
-                    # or the encoder cache is exhausted.
-                    # NOTE(woosuk): By using `continue` instead of `break` here,
-                    # we intentionally relax the strict FCFS scheduling policy
-                    # to allow lower-priority requests to be scheduled when a
-                    # higher-priority request is blocked by encoder constraints.
-                    req_index += 1
-                    continue
-            else:
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because one of the following
+                # reasons:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
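
To see why num_new_tokens can now legitimately be zero, consider chunked prefill with pipeline parallelism: once the scheduled_req_ids guard is gone, the scheduler revisits a request whose whole prompt has already been handed to still-in-flight batches. A plain-numbers check, using hypothetical values matching the test above rather than vLLM objects:

# 12 prompt tokens, token budget of 10 per batch.
num_prompt_tokens = 12

# Step 1: a chunk of 10 goes into Batch 1; num_computed_tokens is advanced
# at schedule time, before the batch finishes executing.
num_computed_tokens = 10
# Step 2: the remaining chunk of 2 goes into Batch 2.
num_computed_tokens = 12

# Step 3: Batches 1 and 2 are still in flight, so no output token exists yet.
# The scheduler visits the request again:
num_new_tokens = num_prompt_tokens - num_computed_tokens
assert num_new_tokens == 0   # old code would have tripped `assert > 0`;
                             # new code simply skips the request (`continue`).
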
@@ -243,7 +239,6 @@
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if request.use_structured_output:
                 # PERF: in case of chunked prefill,
                 # request might not include any new tokens.
@@ -382,7 +377,6 @@
                     request.request_id] = req_index
                 req_index += 1
             self.running.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED,
                                      scheduled_timestamp)
@@ -521,18 +515,21 @@
         num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
         new_token_ids = request.all_token_ids[
             num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
+
+        req_data_queue = self._cached_reqs_data.get(request.request_id)
+        if req_data_queue:
+            req_data = req_data_queue.popleft()
             req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
+            # No cached request data, or all cached request data has been
+            # used by the scheduled requests.
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
                                                       new_token_ids,
                                                       new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
     def _try_schedule_encoder_inputs(
@@ -561,6 +558,8 @@
         Note that num_computed_tokens includes both locally cached
         blocks and externally cached blocks (via KVConnector).
         """
+        if num_new_tokens == 0 or not request.has_encoder_inputs:
+            return [], num_new_tokens, encoder_budget
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
         assert mm_positions is not None
@@ -728,10 +727,13 @@
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
 
-            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)
 
+        # Return the cached request data to the queue so they can be reused.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            self._cached_reqs_data[req_data.req_id].append(req_data)
+
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
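
The CachedRequestData handling above amounts to a small per-request object pool: the scheduler pops an idle object from the deque when one exists, and update_from_output returns it once the batch completes, so a two-deep batch queue needs at most two live objects per request. A standalone sketch of the same pattern, with generic names rather than vLLM classes:

from collections import defaultdict, deque

class Pool:
    """Per-key pool of reusable record objects (illustrative only)."""

    def __init__(self, make_record):
        self._make_record = make_record
        self._free = defaultdict(deque)   # key -> deque of idle records

    def acquire(self, key, **fields):
        queue = self._free.get(key)
        if queue:
            record = queue.popleft()      # reuse an idle record
            record.update(fields)
        else:
            record = self._make_record(**fields)   # none idle: allocate a new one
        return record

    def release(self, key, record):
        self._free[key].append(record)    # hand the record back for reuse

pool = Pool(make_record=dict)
r1 = pool.acquire("req0", new_token_ids=[1, 2])
pool.release("req0", r1)
r2 = pool.acquire("req0", new_token_ids=[3])
assert r1 is r2   # the same object was recycled
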
@@ -774,7 +776,6 @@
 
             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
-                self.scheduled_req_ids.discard(request.request_id)
             else:
                 self.waiting.remove(request)
             request.status = finished_status
@@ -795,10 +796,6 @@
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
 
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
-
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

View File

@@ -210,10 +210,10 @@ class EngineCore:
         Note that if nothing to output in this step, None is returned.
 
         The execution flow is as follows:
-        1. Try to schedule a new batch if there are unscheduled requests
-        and the job queue is not full. If a new batch is scheduled, directly
-        return an empty engine core output. In other words, we won't check
-        and return model outputs before the batch queue is full.
+        1. Try to schedule a new batch if the batch queue is not full.
+        If a new batch is scheduled, directly return an empty engine core
+        output. In other words, fulfilling the batch queue has a higher priority
+        than getting model outputs.
         2. If there is no new scheduled batch, meaning that the batch queue
         is full or no other requests can be scheduled, we block until the first
         batch in the job queue is finished.
@@ -223,10 +223,10 @@
         engine_core_outputs = None
         scheduler_output = None
-        # If there are unscheduled requests and the job queue
-        # is not full, schedule a new batch. Note that this is not blocking.
-        if (self.scheduler.get_num_unscheduled_requests() > 0
-                and not self.batch_queue.full()):
+        # Try to schedule a new batch if the batch queue is not full, but
+        # the scheduler may return an empty batch if all requests are scheduled.
+        # Note that this is not blocking.
+        if not self.batch_queue.full():
             scheduler_output = self.scheduler.schedule()
             if scheduler_output.total_num_scheduled_tokens > 0:
                 future = self.model_executor.execute_model(scheduler_output)
@@ -238,6 +238,10 @@
 
         # If no more requests can be scheduled and the job queue is not empty,
         # block until the first batch in the job queue is finished.
+        # TODO(comaniac): Ideally we should peek the first batch in the
+        # job queue to check if it's finished before scheduling a new batch,
+        # but peeking the first element in a queue is not thread-safe,
+        # so we need more work.
         if not scheduled_batch and not self.batch_queue.empty():
             future, scheduler_output = self.batch_queue.get_nowait()
             # Blocking until the first result is available.
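
Taken together, step_with_batch_queue is a two-stage pipeline driver: keep filling the bounded batch queue while there is room, and only block on the oldest in-flight batch when nothing new can be scheduled. A minimal sketch of that control flow, with schedule/execute/finalize as made-up stand-ins for the real scheduler and executor calls (illustrative only, not the vLLM implementation):

import queue
from concurrent.futures import ThreadPoolExecutor

def step_with_batch_queue(batch_queue, schedule, execute, finalize):
    """One engine step: prefer filling the queue, otherwise drain it."""
    if not batch_queue.full():
        batch = schedule()                    # non-blocking
        if batch:                             # something was scheduled
            batch_queue.put_nowait((execute(batch), batch))
            return None                       # no output produced this step

    if not batch_queue.empty():
        future, batch = batch_queue.get_nowait()
        return finalize(batch, future.result())   # block on the oldest batch
    return None

# Tiny usage example with a batch queue of depth 2.
pool = ThreadPoolExecutor(max_workers=1)
pending = iter([["a"], ["b"], ["c"], None, None, None])
bq = queue.Queue(maxsize=2)
for _ in range(6):
    out = step_with_batch_queue(
        bq,
        schedule=lambda: next(pending),
        execute=lambda b: pool.submit(lambda: [t.upper() for t in b]),
        finalize=lambda b, r: r)
    if out is not None:
        print(out)   # prints ['A'], ['B'], ['C'] over the six steps
pool.shutdown()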