[V1][PP] Optimization: continue scheduling prefill chunks (#17080)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
commit c0dfd97519
parent a9138e85b1
@@ -437,7 +437,6 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
-        scheduler.scheduled_req_ids.add(req.request_id)
 
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -489,7 +488,6 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
-        scheduler.scheduled_req_ids.add(req.request_id)
 
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -539,7 +537,6 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
-        scheduler.scheduled_req_ids.add(req.request_id)
 
     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -589,7 +586,6 @@ def test_stop_via_update_from_output():
     requests[0].num_computed_tokens = requests[0].num_tokens
     scheduler.requests[requests[0].request_id] = requests[0]
     scheduler.running.append(requests[0])
-    scheduler.scheduled_req_ids.add(requests[0].request_id)
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import copy
-import threading
 import time
 import uuid
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 
 import pytest
 from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                 self, kv_cache_configs: list[KVCacheConfig]) -> None:
             super().initialize_from_config(kv_cache_configs)
 
-            # This executor actually can only run 1 batch at a time
-            self.semaphore = threading.Semaphore(1)
+            # Create a thread pool with a single worker
+            self.thread_pool = ThreadPoolExecutor(max_workers=1)
 
         def execute_model(
             self,
             scheduler_output,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
-            future: Future[ModelRunnerOutput] = Future()
 
-            def _thread_wrapper(scheduler_output, future):
-                with self.semaphore:
+            def _execute():
                 output = self.collective_rpc("execute_model",
                                              args=(scheduler_output, ))
                 # Make a copy because output[0] may be reused
                 # by the next batch.
-                    output = copy.deepcopy(output[0])
-                    future.set_result(output)
+                return copy.deepcopy(output[0])
 
-            threading.Thread(target=_thread_wrapper,
-                             args=(scheduler_output, future)).start()
-            return future
+            # Use the thread pool instead of creating a new thread
+            return self.thread_pool.submit(_execute)
 
         @property
         def max_concurrent_batches(self) -> int:
             return 2
 
+        def shutdown(self):
+            if hasattr(self, 'thread_pool'):
+                self.thread_pool.shutdown(wait=False)
+
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
 
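
As context for the change above: the test swaps a hand-rolled semaphore-plus-thread setup for a single-worker ThreadPoolExecutor. The sketch below is illustrative only, not vLLM code (ToyExecutor and the sleep stand in for the real executor and forward pass); it shows why the two are equivalent here: submit() returns a Future immediately, while the single worker still runs at most one batch at a time.

import time
from concurrent.futures import Future, ThreadPoolExecutor


class ToyExecutor:
    def __init__(self) -> None:
        # One worker thread: submitted batches are queued and run serially.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def execute_model(self, scheduler_output: str) -> Future:
        def _execute() -> str:
            time.sleep(0.01)  # stand-in for the actual forward pass
            return f"output for {scheduler_output}"

        # submit() returns a Future right away; the caller keeps scheduling.
        return self.thread_pool.submit(_execute)


if __name__ == "__main__":
    executor = ToyExecutor()
    futures = [executor.execute_model(f"batch-{i}") for i in range(3)]
    # Futures resolve in submission order because there is a single worker.
    print([f.result() for f in futures])
    executor.thread_pool.shutdown(wait=False)
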
@@ -299,14 +298,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         # Schedule Batch 1: (10, req0)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 1
         scheduler_output = engine_core.batch_queue.queue[-1][1]
         assert scheduler_output.num_scheduled_tokens[0] == 10
         # num_computed_tokens should have been updated immediately.
         assert engine_core.scheduler.requests[
             req0.request_id].num_computed_tokens == 10
 
+        # Schedule Batch 2: (2, req0), (8, req1)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 2
         scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 2
         assert scheduler_output.num_scheduled_tokens[1] == 8
+        # num_computed_tokens should have been updated immediately.
+        assert engine_core.scheduler.requests[0].num_computed_tokens == 12
+        assert engine_core.scheduler.requests[1].num_computed_tokens == 8
 
         assert engine_core.scheduler.get_num_unfinished_requests() == 2
 
-        # Loop through both requests.
-        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+        # Batch queue is full. Finish Batch 1.
         engine_core.step_with_batch_queue()
 
-        # Reaching here when got the result of the first request.
-        while engine_core.scheduler.get_num_unfinished_requests() == 1:
+        # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+        # because it is in the decoding stage now.
         engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 4
+
+        # Batch queue is full. Finish Batch 2. Get first token of req0.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
+
+        # Schedule Batch 4: (1, req0).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 1
+
+        # Batch queue is full. Finish Batch 3. Get first token of req1.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
+
+        # Schedule Batch 5: (1, req1).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 1
+
+        # Loop until req0 is finished.
+        step = 0
+        req_id = 0
+        expected_num_tokens = [
+            engine_core.scheduler.requests[0].num_tokens + 1,
+            engine_core.scheduler.requests[1].num_tokens + 1,
+        ]
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            output = engine_core.step_with_batch_queue()
+            if step % 2 == 0:
+                # Even steps consumes an output.
+                assert output is not None
+                assert len(output.outputs) == 1
+                if req_id in engine_core.scheduler.requests:
+                    assert engine_core.scheduler.requests[
+                        req_id].num_tokens == expected_num_tokens[req_id]
+                expected_num_tokens[req_id] += 1
+                req_id = (req_id + 1) % 2
+            else:
+                # Odd steps schedules a new batch.
+                assert output is None
+            step += 1
 
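
The chunk sizes asserted above follow from simple arithmetic, assuming (as the assertions imply) two 12-token prompts and a 10-token per-step scheduling budget. The hypothetical helper chunk_schedule below reproduces only that arithmetic; it is not the vLLM scheduler.

# Reproduce the prefill-chunk sizes asserted above: with a 10-token budget,
# req0's 12-token prompt is split 10 + 2, the leftover budget in the second
# step (10 - 2 = 8) goes to req1, and req1's last 4 tokens land in batch 3.
def chunk_schedule(prompt_lens: dict[str, int], budget: int) -> list[dict[str, int]]:
    computed = {rid: 0 for rid in prompt_lens}
    batches = []
    while any(computed[rid] < prompt_lens[rid] for rid in prompt_lens):
        remaining_budget = budget
        batch: dict[str, int] = {}
        for rid, plen in prompt_lens.items():
            if remaining_budget == 0:
                break
            num_new = min(plen - computed[rid], remaining_budget)
            if num_new > 0:
                batch[rid] = num_new
                computed[rid] += num_new
                remaining_budget -= num_new
        batches.append(batch)
    return batches


print(chunk_schedule({"req0": 12, "req1": 12}, budget=10))
# [{'req0': 10}, {'req0': 2, 'req1': 8}, {'req1': 4}]
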
@@ -117,11 +117,6 @@ class SchedulerInterface(ABC):
         not yet returned in SchedulerOutputs."""
         return self.has_unfinished_requests() or self.has_finished_requests()
 
-    @abstractmethod
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        raise NotImplementedError
-
     @abstractmethod
     def reset_prefix_cache(self) -> bool:
         """Reset the prefix cache for KV cache.
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import time
-from collections import deque
+from collections import defaultdict, deque
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -88,9 +88,6 @@ class Scheduler(SchedulerInterface):
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
         self.running: list[Request] = []
-        # The requests that have been scheduled and are being executed
-        # by the executor.
-        self.scheduled_req_ids: set[str] = set()
 
         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@ class Scheduler(SchedulerInterface):
 
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Request id -> deque of CachedRequestData
+        self._cached_reqs_data: dict[
+            str, deque[CachedRequestData]] = defaultdict(deque)
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
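
Why a deque instead of a single cached object per request: once prefill chunks of the same request can be in flight concurrently (PP > 1), several scheduler outputs may reference that request at the same time, and each needs its own CachedRequestData. A rough lifecycle sketch with stand-in types (FakeCachedRequestData and the two helpers are made up for illustration, not the real scheduler code):

from collections import defaultdict, deque
from dataclasses import dataclass


@dataclass
class FakeCachedRequestData:  # stand-in for vLLM's CachedRequestData
    req_id: str
    new_token_ids: list


# Request id -> deque of reusable CachedRequestData objects.
cached_reqs_data: dict = defaultdict(deque)


def make_cached_request_data(req_id: str, new_token_ids: list) -> FakeCachedRequestData:
    # Reuse a previously returned object if one is available, else allocate.
    data_queue = cached_reqs_data.get(req_id)
    if data_queue:
        data = data_queue.popleft()
        data.new_token_ids = new_token_ids
    else:
        data = FakeCachedRequestData(req_id, new_token_ids)
    return data


def return_cached_request_data(data: FakeCachedRequestData) -> None:
    # Called once the corresponding batch has finished executing.
    cached_reqs_data[data.req_id].append(data)


# Two prefill chunks of the same request can be in flight at once, so two
# distinct objects are handed out before either is returned to the queue.
chunk1 = make_cached_request_data("req0", [1, 2, 3])
chunk2 = make_cached_request_data("req0", [4, 5])
assert chunk1 is not chunk2
return_cached_request_data(chunk1)
return_cached_request_data(chunk2)
assert make_cached_request_data("req0", [6]) is chunk1  # reused
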
@@ -171,10 +169,6 @@ class Scheduler(SchedulerInterface):
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
-            if request.request_id in self.scheduled_req_ids:
-                # This request has already been scheduled.
-                req_index += 1
-                continue
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
@@ -183,33 +177,35 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens > 0
 
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
                 self.max_model_len - request.num_computed_tokens)
-            assert num_new_tokens > 0
 
             # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
                  new_encoder_budget) = self._try_schedule_encoder_inputs(
                      request, request.num_computed_tokens, num_new_tokens,
                      encoder_budget)
 
             if num_new_tokens == 0:
-                # The request cannot be scheduled because the encoder budget
-                # or the encoder cache is exhausted.
-                # NOTE(woosuk): By using `continue` instead of `break` here,
-                # we intentionally relax the strict FCFS scheduling policy
-                # to allow lower-priority requests to be scheduled when a
-                # higher-priority request is blocked by encoder constraints.
+                # The request cannot be scheduled because one of the following
+                # reasons:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
                 req_index += 1
                 continue
-            else:
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
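
A small illustration of case 1 in the new comment, with made-up numbers: because num_computed_tokens is advanced as soon as a chunk is scheduled (not when it finishes), a request whose whole prompt is already in flight has nothing left to schedule until an output comes back, and the loop simply skips it.

# Illustrative numbers only: a 12-token prompt whose chunks (10 + 2) have
# already been handed to the executor. With PP > 1 the first chunk's output
# may still be in flight, so there is nothing new to schedule yet.
num_tokens_with_spec = 12   # prompt tokens (no draft tokens in this example)
num_computed_tokens = 12    # bumped at schedule time, not at completion time

num_new_tokens = num_tokens_with_spec - num_computed_tokens
if num_new_tokens == 0:
    # Case 1 above: skip this request for now (continue, not break) and let
    # lower-priority requests use the remaining token budget.
    pass
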
@@ -243,7 +239,6 @@ class Scheduler(SchedulerInterface):
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if request.use_structured_output:
                 # PERF: in case of chunked prefill,
                 # request might not include any new tokens.
@@ -382,7 +377,6 @@ class Scheduler(SchedulerInterface):
                 request.request_id] = req_index
             req_index += 1
             self.running.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED,
                                      scheduled_timestamp)
@@ -521,18 +515,21 @@ class Scheduler(SchedulerInterface):
         num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
         new_token_ids = request.all_token_ids[
             num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
+
+        req_data_queue = self._cached_reqs_data.get(request.request_id)
+        if req_data_queue:
+            req_data = req_data_queue.popleft()
             req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
+            # No cached request data, or all cached request data has been
+            # used by the scheduled requests.
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
                                                       new_token_ids,
                                                       new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
     def _try_schedule_encoder_inputs(
@@ -561,6 +558,8 @@ class Scheduler(SchedulerInterface):
         Note that num_computed_tokens includes both locally cached
         blocks and externally cached blocks (via KVConnector).
         """
+        if num_new_tokens == 0 or not request.has_encoder_inputs:
+            return [], num_new_tokens, encoder_budget
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
         assert mm_positions is not None
@@ -728,10 +727,13 @@ class Scheduler(SchedulerInterface):
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
 
-            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)
 
+        # Return the cached request data to the queue so they can be reused.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            self._cached_reqs_data[req_data.req_id].append(req_data)
+
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
@@ -774,7 +776,6 @@ class Scheduler(SchedulerInterface):
 
             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
-                self.scheduled_req_ids.discard(request.request_id)
             else:
                 self.waiting.remove(request)
             request.status = finished_status
@@ -795,10 +796,6 @@ class Scheduler(SchedulerInterface):
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
 
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
-
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()
 
@@ -210,10 +210,10 @@ class EngineCore:
         Note that if nothing to output in this step, None is returned.
 
         The execution flow is as follows:
-        1. Try to schedule a new batch if there are unscheduled requests
-        and the job queue is not full. If a new batch is scheduled, directly
-        return an empty engine core output. In other words, we won't check
-        and return model outputs before the batch queue is full.
+        1. Try to schedule a new batch if the batch queue is not full.
+        If a new batch is scheduled, directly return an empty engine core
+        output. In other words, fulfilling the batch queue has a higher priority
+        than getting model outputs.
         2. If there is no new scheduled batch, meaning that the batch queue
         is full or no other requests can be scheduled, we block until the first
         batch in the job queue is finished.
@@ -223,10 +223,10 @@ class EngineCore:
 
         engine_core_outputs = None
         scheduler_output = None
-        # If there are unscheduled requests and the job queue
-        # is not full, schedule a new batch. Note that this is not blocking.
-        if (self.scheduler.get_num_unscheduled_requests() > 0
-                and not self.batch_queue.full()):
+        # Try to schedule a new batch if the batch queue is not full, but
+        # the scheduler may return an empty batch if all requests are scheduled.
+        # Note that this is not blocking.
+        if not self.batch_queue.full():
             scheduler_output = self.scheduler.schedule()
             if scheduler_output.total_num_scheduled_tokens > 0:
                 future = self.model_executor.execute_model(scheduler_output)
@@ -238,6 +238,10 @@ class EngineCore:
 
         # If no more requests can be scheduled and the job queue is not empty,
         # block until the first batch in the job queue is finished.
+        # TODO(comaniac): Ideally we should peek the first batch in the
+        # job queue to check if it's finished before scheduling a new batch,
+        # but peeking the first element in a queue is not thread-safe,
+        # so we need more work.
         if not scheduled_batch and not self.batch_queue.empty():
             future, scheduler_output = self.batch_queue.get_nowait()
             # Blocking until the first result is available.
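
A rough sketch of the two-phase flow described in the docstring above (illustrative only; the schedule/execute/update callbacks replace the real scheduler, executor, and output processing, and the real step_with_batch_queue also handles stats and empty schedules):

import queue
from concurrent.futures import Future, ThreadPoolExecutor

BATCH_QUEUE_SIZE = 2  # e.g. the number of pipeline-parallel stages
batch_queue: queue.Queue = queue.Queue(maxsize=BATCH_QUEUE_SIZE)
executor = ThreadPoolExecutor(max_workers=1)


def step_with_batch_queue_sketch(schedule, execute, update):
    # Phase 1: keep the pipeline fed. If there is room in the batch queue,
    # schedule a new batch and return immediately without touching outputs.
    scheduled_batch = False
    if not batch_queue.full():
        scheduler_output = schedule()
        if scheduler_output is not None:
            future: Future = executor.submit(execute, scheduler_output)
            batch_queue.put_nowait((future, scheduler_output))
            scheduled_batch = True

    # Phase 2: only when nothing new could be scheduled, block on the oldest
    # in-flight batch and turn its model output into engine core outputs.
    if not scheduled_batch and not batch_queue.empty():
        future, scheduler_output = batch_queue.get_nowait()
        model_output = future.result()  # blocks until that batch finishes
        return update(scheduler_output, model_output)
    return None
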