[Core] Support async scheduling with uniproc executor (#24219)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Authored by Nick Hill on 2025-09-12 16:34:28 -07:00, committed by GitHub
parent 8226dd56bf
commit 4fdd6f5cbf
9 changed files with 103 additions and 55 deletions
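
For context, a minimal usage sketch of the combination this commit enables (the model name and prompt are illustrative, not taken from the change; it assumes the LLM entrypoint forwards these engine args as usual):

    # Hedged sketch: async scheduling together with the single-process executor.
    # The model name and prompt are placeholders; the two engine args shown are
    # the ones exercised by this commit.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",           # illustrative model
        async_scheduling=True,               # overlap scheduling with execution
        distributed_executor_backend="uni",  # single-process executor
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)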


@@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("async_scheduling", [True, False])
+@pytest.mark.parametrize("model_executor", ["uni", "mp"])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
@@ -70,6 +72,8 @@ def test_models(
     backend: str,
     max_tokens: int,
     enforce_eager: bool,
+    async_scheduling: bool,
+    model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
@@ -77,6 +81,12 @@ def test_models(
             "VLLM_USE_V1") and envs.VLLM_USE_V1:
         pytest.skip("enable_prompt_embeds is not supported in v1.")
+    if not envs.VLLM_USE_V1:
+        if async_scheduling:
+            pytest.skip("async_scheduling only supported in v1.")
+        if model_executor != "uni":
+            pytest.skip("only test uniproc executor for v0.")
     if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
         pytest.skip(
             f"{backend} does not support gemma2 with full context length.")
@@ -98,11 +108,15 @@ def test_models(
         prompt_embeds = hf_model.get_prompt_embeddings(
             example_prompts)
-    with VllmRunner(model,
-                    max_model_len=8192,
-                    enforce_eager=enforce_eager,
-                    enable_prompt_embeds=enable_prompt_embeds,
-                    gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(
+            model,
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
+            async_scheduling=async_scheduling,
+            distributed_executor_backend=model_executor,
+    ) as vllm_model:
         if enable_prompt_embeds:
             vllm_outputs = vllm_model.generate_greedy(
                 prompt_embeds, max_tokens)


@@ -257,9 +257,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         def execute_model(
             self,
             scheduler_output,
+            non_block=False,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
+            # DummyExecutor used only for testing async case.
+            assert non_block
+
             def _execute():
                 output = self.collective_rpc("execute_model",
                                              args=(scheduler_output, ))


@@ -1296,11 +1296,8 @@ class EngineArgs:
             # Async scheduling does not work with the uniprocess backend.
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "mp"
-                logger.info("Using mp-based distributed executor backend "
-                            "for async scheduling.")
-            if self.distributed_executor_backend == "uni":
-                raise ValueError("Async scheduling is not supported with "
-                                 "uni-process backend.")
+                logger.info("Defaulting to mp-based distributed executor "
+                            "backend for async scheduling.")
             if self.pipeline_parallel_size > 1:
                 raise ValueError("Async scheduling is not supported with "
                                  "pipeline-parallel-size > 1.")


@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
+from concurrent.futures import Future, ThreadPoolExecutor
+from functools import cached_property
 from multiprocessing import Lock
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -17,6 +18,7 @@ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         run_method)
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.utils import get_and_update_mm_cache
+from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -31,15 +33,7 @@ class UniProcExecutor(ExecutorBase):
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                                rpc_rank=0)
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        local_rank = 0
-        # set local rank as the device index if specified
-        device_info = self.vllm_config.device_config.device.__str__().split(
-            ":")
-        if len(device_info) > 1:
-            local_rank = int(device_info[1])
-        rank = 0
+        distributed_init_method, rank, local_rank = self._distributed_args()
         is_driver_worker = True
         kwargs = dict(
             vllm_config=self.vllm_config,
@@ -50,21 +44,56 @@ class UniProcExecutor(ExecutorBase):
         )
         self.mm_receiver_cache = worker_receiver_cache_from_config(
             self.vllm_config, MULTIMODAL_REGISTRY, Lock())
+        self.async_output_thread: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            self.async_output_thread = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="WorkerAsyncOutput")
         self.collective_rpc("init_worker", args=([kwargs], ))
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")

+    def _distributed_args(self) -> tuple[str, int, int]:
+        """Return (distributed_init_method, rank, local_rank)."""
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        # set local rank as the device index if specified
+        device_info = self.vllm_config.device_config.device.__str__().split(
+            ":")
+        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
+        return distributed_init_method, 0, local_rank
+
+    @cached_property
+    def max_concurrent_batches(self) -> int:
+        return 2 if self.scheduler_config.async_scheduling else 1
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       kwargs: Optional[Dict] = None,
+                       non_block: bool = False) -> List[Any]:
         if kwargs is None:
             kwargs = {}
         if self.mm_receiver_cache is not None and method == "execute_model":
             get_and_update_mm_cache(self.mm_receiver_cache, args)
-        answer = run_method(self.driver_worker, method, args, kwargs)
-        return [answer]
+
+        if not non_block:
+            return [run_method(self.driver_worker, method, args, kwargs)]
+
+        try:
+            result = run_method(self.driver_worker, method, args, kwargs)
+            if isinstance(result, AsyncModelRunnerOutput):
+                if (async_thread := self.async_output_thread) is not None:
+                    return [async_thread.submit(result.get_output)]
+                result = result.get_output()
+            future = Future[Any]()
+            future.set_result(result)
+        except Exception as e:
+            future = Future[Any]()
+            future.set_exception(e)
+        return [future]

     def check_health(self) -> None:
         # UniProcExecutor will always be healthy as long as
@@ -116,8 +145,9 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
         assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
             ("To get deterministic execution in V1, "
              "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
-        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rpc_rank=0)
+        super()._init_executor()
+
+    def _distributed_args(self) -> tuple[str, int, int]:
         # engines are launched in torchrun-compatible launchers
         # so we can use the env:// method.
         # required env vars:
@@ -128,19 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
         distributed_init_method = "env://"
         rank = int(os.environ["RANK"])
         local_rank = int(os.environ["LOCAL_RANK"])
-        is_driver_worker = True
-        kwargs = dict(
-            vllm_config=self.vllm_config,
-            local_rank=local_rank,
-            rank=rank,
-            distributed_init_method=distributed_init_method,
-            is_driver_worker=is_driver_worker,
-        )
-        self.mm_receiver_cache = worker_receiver_cache_from_config(
-            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
-        self.collective_rpc("init_worker", args=([kwargs], ))
-        self.collective_rpc("init_device")
-        self.collective_rpc("load_model")
+        return distributed_init_method, rank, local_rank

     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """


@@ -159,6 +159,9 @@ class EngineCore:
         self.request_block_hasher = get_request_block_hasher(
             block_size, caching_hash_fn)

+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
@@ -331,7 +334,8 @@ class EngineCore:
         model_executed = False
         if self.scheduler.has_requests():
             scheduler_output = self.scheduler.schedule()
-            future = self.model_executor.execute_model(scheduler_output)
+            future = self.model_executor.execute_model(scheduler_output,
+                                                       non_block=True)
             batch_queue.appendleft(
                 (future, scheduler_output))  # type: ignore[arg-type]
@@ -534,9 +538,6 @@ class EngineCoreProc(EngineCore):
             assert addresses.coordinator_input is not None
             logger.info("Waiting for READY message from DP Coordinator...")

-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
-
         # Mark the startup heap as static so that it's ignored by GC.
         # Reduces pause times of oldest generation collections.
         gc.collect()
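
With step_fn now chosen once in the EngineCore constructor, the batch-queue step can keep one extra batch in flight. A rough sketch of that control flow, with placeholder names (schedule, update_from_output stand in for the scheduler's actual interface):

    # Rough sketch of the two-deep pipelining behind step_with_batch_queue.
    # The scheduler/executor methods named here are placeholders for this sketch.
    from collections import deque
    from concurrent.futures import Future
    from typing import Any, Optional

    def step_with_batch_queue(scheduler: Any, executor: Any,
                              batch_queue: deque) -> Optional[Any]:
        # Launch the next batch without waiting for the previous one.
        if scheduler.has_requests():
            scheduler_output = scheduler.schedule()
            future: Future = executor.execute_model(scheduler_output,
                                                    non_block=True)
            batch_queue.appendleft((future, scheduler_output))

        # Block only on the oldest in-flight batch, once the queue is full
        # or nothing new could be scheduled.
        if batch_queue and (len(batch_queue) == batch_queue.maxlen
                            or not scheduler.has_requests()):
            future, scheduler_output = batch_queue.pop()
            model_output = future.result()
            return scheduler.update_from_output(scheduler_output, model_output)
        return None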


@@ -245,8 +245,8 @@ class InprocClient(EngineCoreClient):
         self.engine_core = EngineCore(*args, **kwargs)

     def get_output(self) -> EngineCoreOutputs:
-        outputs, _ = self.engine_core.step()
-        return outputs.get(0) or EngineCoreOutputs()
+        outputs, _ = self.engine_core.step_fn()
+        return outputs and outputs.get(0) or EngineCoreOutputs()

     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from concurrent.futures import Future
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
 from vllm.utils import resolve_obj_by_qualname
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -86,12 +87,22 @@ class Executor(ExecutorBase):
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
         return self.collective_rpc("get_kv_cache_spec")

+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict] = None,
+                       non_block: bool = False) -> list[Any]:
+        raise NotImplementedError
+
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         output = self.collective_rpc("execute_model",
-                                     args=(scheduler_output, ))
+                                     args=(scheduler_output, ),
+                                     non_block=non_block)
         return output[0]

     def execute_dummy_batch(self) -> None:
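
From the caller's side, the new non_block flag decides whether execute_model blocks in the executor or returns a Future to wait on later. A hedged illustration of driving this interface (the executor and scheduler_output are assumed to come from the engine, not defined here):

    # Hedged illustration of driving the new execute_model signature.
    # `executor` and `scheduler_output` are assumed to exist already.
    from concurrent.futures import Future
    from typing import Any

    def run_pipelined_step(executor: Any, scheduler_output: Any) -> Any:
        # Non-blocking call: returns a Future instead of the output itself.
        future: Future = executor.execute_model(scheduler_output, non_block=True)
        # The engine is free to schedule the next batch here.
        return future.result()  # wait only when the output is actually needed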


@@ -11,7 +11,7 @@ import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
+from functools import cached_property, partial
 from multiprocessing.connection import Connection
 from multiprocessing.process import BaseProcess
 from multiprocessing.synchronize import Lock as LockType
@@ -37,6 +37,7 @@ from vllm.multimodal.cache import worker_receiver_cache_from_config
 from vllm.utils import (decorate_logs, get_distributed_init_method,
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.executor.abstract import Executor, FailureCallback
 from vllm.v1.executor.utils import get_and_update_mm_cache
 from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
@@ -174,9 +175,9 @@ class MultiprocExecutor(Executor):
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        non_block = self.max_concurrent_batches > 1

         if not self.has_connector:
             # get output only from a single worker (output_rank)
@@ -328,7 +329,7 @@ class MultiprocExecutor(Executor):
             self.collective_rpc("check_health", timeout=10)
             return

-    @property
+    @cached_property
     def max_concurrent_batches(self) -> int:
         if self.scheduler_config.async_scheduling:
             return 2
@@ -632,7 +633,8 @@ class WorkerProc:
                 result = (WorkerProc.ResponseStatus.FAILURE, str(output))
             else:
                 result = (WorkerProc.ResponseStatus.SUCCESS, output)
-            self.worker_response_mq.enqueue(result)
+            if (response_mq := self.worker_response_mq) is not None:
+                response_mq.enqueue(result)

     def handle_output(self, output: Any):
         """Handles output from the worker. If async scheduling is enabled,


@@ -66,11 +66,13 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         """Execute the model on the Ray workers.

         Args:
             scheduler_output: The scheduler output to execute.
+            non_block: If True, the method will return a Future.

         Returns:
             The model runner output.
@@ -84,7 +86,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         if not self.has_connector:
             # Get output only from a single worker (output_rank)
             # When PP is not used, we block here until the result is available.
-            if self.max_concurrent_batches == 1:
+            if not non_block:
                 return refs[0].get()

             # When PP is used, we return a FutureWrapper immediately so that
@@ -92,7 +94,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
             return FutureWrapper(refs)

         # Get output from all workers when connector is present
-        if self.max_concurrent_batches == 1:
+        if not non_block:
             # Block and get results from all workers
             outputs = [ref.get() for ref in refs]
             return self.kv_output_aggregator.aggregate(outputs)
@@ -106,4 +108,3 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         if reconfig_request.new_data_parallel_rank == \
                 ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
             self.shutdown()
-        return