[V1][PP] Run engine busy loop with batch queue (#13064)
commit 9206b3d7ec
parent ed0de3e4b8
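The commit ships no message body beyond the title, so as orientation for the diff below: the core idea is to let the engine schedule a new batch while earlier batches are still executing, by keeping a bounded queue of (future, scheduler_output) pairs. A minimal standalone sketch of that schedule-ahead pattern follows; the names (fake_execute, batch_queue, QUEUE_SIZE) are illustrative stand-ins, not vLLM APIs.

# Toy sketch of the schedule-ahead pattern this commit adds (not vLLM code).
import queue
import time
from concurrent.futures import Future, ThreadPoolExecutor

QUEUE_SIZE = 2  # stands in for Executor.max_concurrent_batches (PP size)
pool = ThreadPoolExecutor(max_workers=QUEUE_SIZE)
batch_queue: "queue.Queue[tuple[Future, int]]" = queue.Queue(QUEUE_SIZE)


def fake_execute(batch_id: int) -> str:
    time.sleep(0.01)  # pretend this is a forward pass on one PP stage
    return f"output-{batch_id}"


pending = list(range(5))  # batches that still need to be scheduled
while pending or not batch_queue.empty():
    if pending and not batch_queue.full():
        # Schedule and submit without waiting for earlier batches to finish.
        batch_id = pending.pop(0)
        batch_queue.put_nowait((pool.submit(fake_execute, batch_id), batch_id))
        continue
    # Queue full (or nothing left to schedule): block on the oldest batch.
    future, batch_id = batch_queue.get()
    print(batch_id, future.result())
pool.shutdown()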
tests/v1/core/test_scheduler.py
@@ -213,3 +213,54 @@ def test_schedule_partial_requests():
    assert output.num_scheduled_tokens[requests[0].request_id] == 1
    assert output.num_scheduled_tokens[requests[1].request_id] == 700
    assert requests[2].request_id not in output.num_scheduled_tokens


def test_schedule_concurrent_batches():
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=2,
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=512,
    )

    # Schedule the first request.
    scheduler.add_request(requests[0])
    scheduler_output0 = scheduler.schedule()
    assert len(scheduler_output0.scheduled_new_reqs) == 1
    assert scheduler_output0.num_scheduled_tokens[
        requests[0].request_id] == 512

    # The first request is still running, so only schedule the second request.
    scheduler.add_request(requests[1])
    scheduler_output1 = scheduler.schedule()
    assert len(scheduler_output1.scheduled_new_reqs) == 1
    assert scheduler_output1.num_scheduled_tokens[
        requests[1].request_id] == 512

    # Model output of the first request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[0],
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(scheduler_output0, model_runner_output)

    # Schedule the next step.
    # The first request can be scheduled again while the second
    # request is still running.
    scheduler_output2 = scheduler.schedule()
    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1

    # Model output of the second request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[0],
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(scheduler_output1, model_runner_output)
tests/v1/engine/test_engine_core.py
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import threading
import time
import uuid
from concurrent.futures import Future

import pytest
from transformers import AutoTokenizer
@@ -12,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -191,3 +196,85 @@ def test_engine_core_advanced_sampling(monkeypatch):
        )
        engine_core.add_request(request2)
        _check_engine_state()


@fork_new_process_for_each_test
def test_engine_core_concurrent_batches(monkeypatch):
    """
    Test that the engine can handle multiple concurrent batches.
    """

    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
        request = make_request()
        request.sampling_params.max_tokens = max_tokens
        return request

    class DummyExecutor(UniProcExecutor):

        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
            super().initialize(kv_cache_config)

            # This executor can actually only run one batch at a time.
            self.semaphore = threading.Semaphore(1)

        def execute_model(
            self,
            scheduler_output,
        ) -> Future[ModelRunnerOutput]:
            """Make execute_model non-blocking."""
            future: Future[ModelRunnerOutput] = Future()

            def _thread_wrapper(scheduler_output, future):
                with self.semaphore:
                    output = self.collective_rpc("execute_model",
                                                 args=(scheduler_output, ))
                    # Make a copy because output[0] may be reused
                    # by the next batch.
                    output = copy.deepcopy(output[0])
                    future.set_result(output)

            threading.Thread(target=_thread_wrapper,
                             args=(scheduler_output, future)).start()
            return future

        @property
        def max_concurrent_batches(self) -> int:
            return 2

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(
            model=MODEL_NAME,
            # To test concurrent batches.
            max_num_seqs=2,
            # Avoid all requests being scheduled at once.
            enable_prefix_caching=False,
            max_num_batched_tokens=10,
        )
        vllm_config = engine_args.create_engine_config()
        engine_core = EngineCore(vllm_config=vllm_config,
                                 log_stats=False,
                                 executor_class=DummyExecutor)
        assert engine_core.batch_queue is not None

        # Add two requests in a row.
        req = make_request_with_max_tokens(5)
        engine_core.add_request(req)
        req = make_request_with_max_tokens(5)
        engine_core.add_request(req)

        # First saturate the batch queue.
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 1
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 2
        assert engine_core.scheduler.get_num_unfinished_requests() == 2

        # Loop through both requests.
        while engine_core.scheduler.get_num_unfinished_requests() == 2:
            engine_core.step_with_batch_queue()

        # We reach here once the first request has finished.
        while engine_core.scheduler.get_num_unfinished_requests() == 1:
            engine_core.step_with_batch_queue()
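Aside (not part of the diff): the DummyExecutor above turns a blocking collective_rpc call into a Future by running it on a helper thread. A generic, self-contained sketch of that same pattern, with hypothetical names:

# Hedged sketch: wrap any blocking call in a Future via a helper thread.
import threading
from concurrent.futures import Future
from typing import Callable, TypeVar

T = TypeVar("T")


def submit_in_thread(fn: Callable[[], T]) -> "Future[T]":
    future: "Future[T]" = Future()

    def _run() -> None:
        try:
            future.set_result(fn())
        except Exception as exc:  # surface errors through the Future
            future.set_exception(exc)

    threading.Thread(target=_run, daemon=True).start()
    return future


# Example usage: the caller gets the Future back immediately.
fut = submit_in_thread(lambda: 40 + 2)
assert fut.result() == 42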
vllm/v1/core/scheduler.py
@@ -58,6 +58,9 @@ class Scheduler:
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []
        # The requests that have been scheduled and are being executed
        # by the executor.
        self.scheduled_req_ids: Set[str] = set()

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
@@ -118,6 +121,11 @@ class Scheduler:
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            if request.request_id in self.scheduled_req_ids:
                # This request has already been scheduled.
                req_index += 1
                continue

            num_new_tokens = request.num_tokens - request.num_computed_tokens
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0
@@ -164,6 +172,7 @@ class Scheduler:

            # Schedule the request.
            scheduled_running_reqs.append(request)
            self.scheduled_req_ids.add(request.request_id)
            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in new_blocks
            ]
@@ -251,6 +260,7 @@ class Scheduler:

                self.waiting.popleft()
                self.running.append(request)
                self.scheduled_req_ids.add(request.request_id)
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                    self.request_scheduled(request, scheduled_timestamp)
@@ -519,6 +529,7 @@ class Scheduler:
                        stop_reason=request.stop_reason,
                        events=request.take_events()))

            self.scheduled_req_ids.remove(request.request_id)
            if not stopped:
                new_running.append(request)

@@ -575,6 +586,8 @@ class Scheduler:

            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
                if request.request_id in self.scheduled_req_ids:
                    self.scheduled_req_ids.remove(request.request_id)
            else:
                self.waiting.remove(request)
            request.status = finished_status
@@ -595,6 +608,10 @@ class Scheduler:
    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0

    def get_num_unscheduled_requests(self) -> int:
        """Number of requests that are not being processed by the executor."""
        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)

    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

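Aside (not part of the diff): the new scheduled_req_ids set is what makes concurrent batches safe. schedule() skips any request that already has a batch in flight, and the number of schedulable requests is simply unfinished minus in-flight. A toy illustration with plain Python collections:

# Toy model of the scheduled_req_ids bookkeeping (illustrative only).
running = ["req-0", "req-1", "req-2"]  # unfinished requests
scheduled_req_ids = {"req-0"}          # req-0 already has a batch in flight

# schedule() skips requests that are already being executed.
schedulable = [r for r in running if r not in scheduled_req_ids]
assert schedulable == ["req-1", "req-2"]

# Mirrors Scheduler.get_num_unscheduled_requests().
assert len(running) - len(scheduled_req_ids) == 2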
vllm/v1/engine/core.py
@@ -4,8 +4,9 @@ import queue
import signal
import threading
import time
from concurrent.futures import Future
from multiprocessing.connection import Connection
from typing import Any, List, Tuple, Type
from typing import Any, List, Optional, Tuple, Type

import psutil
import zmq
@@ -18,11 +19,12 @@ from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.version import __version__ as VLLM_VERSION
@@ -66,9 +68,22 @@ class EngineCore:
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        start = time.time()
@@ -135,7 +150,55 @@ class EngineCore:
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)
            scheduler_output, output)  # type: ignore
        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if there is nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no newly scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        # If all requests are scheduled or the job queue is full,
        # block until the first batch in the job queue is finished.
        if (scheduler_output is None
                or scheduler_output.total_num_scheduled_tokens == 0):
            try:
                future, scheduler_output = self.batch_queue.get(
                    timeout=POLLING_TIMEOUT_S)
                # Blocking until the first result is available.
                model_output = future.result()
                self.batch_queue.task_done()
                engine_core_outputs = self.scheduler.update_from_output(
                    scheduler_output, model_output)
            except queue.Empty:
                # If the queue is empty (timeout at .get), return
                # an empty EngineCoreOutputs for logging.
                engine_core_outputs = EngineCoreOutputs(
                    outputs=[], scheduler_stats=self.scheduler.make_stats())

        return engine_core_outputs

    def shutdown(self):
@@ -226,6 +289,9 @@ class EngineCoreProc(EngineCore):
    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
@@ -249,10 +315,11 @@ class EngineCoreProc(EngineCore):
                self._handle_client_request(*req)

            # 3) Step the engine core.
            outputs = self.step()
            outputs = step_fn()

            # 5) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(outputs)
            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:

vllm/v1/executor/abstract.py
@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Type
from concurrent.futures import Future
from typing import List, Type, Union

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
@@ -33,6 +32,8 @@ class Executor(ExecutorBase):
                f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
@@ -70,11 +71,15 @@ class Executor(ExecutorBase):
    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        return 1

    def profile(self, is_start: bool = True):
        self.collective_rpc("profile", args=(is_start, ))

@@ -85,7 +90,3 @@ class UniProcExecutor(UniProcExecutorV0, Executor):

class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    pass


class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    pass

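Aside (not part of the diff): because execute_model() may now return either a ModelRunnerOutput or a Future of one, a caller that does not care about pipelining can normalize the result. The helper below is hypothetical, shown only to illustrate how the Union return type can be handled:

# Hedged sketch: resolve a Union[T, Future[T]] to a plain value.
from concurrent.futures import Future
from typing import TypeVar, Union

T = TypeVar("T")


def resolve(result_or_future: "Union[T, Future[T]]") -> T:
    # Block only when the executor handed back a Future (batch queue enabled).
    if isinstance(result_or_future, Future):
        return result_or_future.result()
    return result_or_future


assert resolve(42) == 42  # a plain value passes straight through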
vllm/v1/executor/ray_distributed_executor.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import Future
from typing import Union

from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput


class FutureWrapper(Future):
    """A wrapper around a Ray output reference to meet the interface
    of .execute_model().
    """

    def __init__(self, ref):
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        if timeout is not None:
            raise NotImplementedError("timeout is not supported")
        return self.ref.get()


class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray distributed executor using Ray Compiled Graphs."""

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.
        """
        return 1  # self.vllm_config.parallel_config.pipeline_parallel_size

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.

        Returns:
            The model runner output.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        # When PP is not used, we block here until the result is available.
        if self.max_concurrent_batches == 1:
            return refs[0].get()

        # When PP is used, we return a FutureWrapper immediately so that
        # the scheduler can yield to the next batch.
        return FutureWrapper(refs[0])
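Aside (not part of the diff): FutureWrapper simply adapts an object exposing .get() (a Ray ObjectRef in the real code) to the concurrent.futures.Future interface that the batch queue expects. A toy usage with a stand-in ref object:

# Toy usage of the FutureWrapper idea; FakeRef stands in for a Ray ObjectRef.
from concurrent.futures import Future


class FakeRef:
    def __init__(self, value):
        self._value = value

    def get(self):
        return self._value


class FutureWrapper(Future):  # same shape as the class in the new file above
    def __init__(self, ref):
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        if timeout is not None:
            raise NotImplementedError("timeout is not supported")
        return self.ref.get()


assert FutureWrapper(FakeRef("model-output")).result() == "model-output"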