From 67a2da890eef2a6fd40384aa5ae80e03beb39490 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 7 Nov 2025 14:11:03 -0800 Subject: [PATCH] [PerfFix] Avoid separate thread for MP executor shm spin (take 2) (#28319) Signed-off-by: Nick Hill --- tests/v1/executor/test_executor.py | 3 +- .../unit/test_output_aggregator.py | 32 ++--- .../kv_transfer/kv_connector/utils.py | 47 +++--- vllm/v1/executor/abstract.py | 4 +- vllm/v1/executor/multiproc_executor.py | 134 +++++++++--------- vllm/v1/executor/ray_executor.py | 11 +- vllm/v1/executor/ray_utils.py | 11 +- vllm/v1/executor/uniproc_executor.py | 43 +++++- vllm/v1/worker/gpu_worker.py | 2 +- 9 files changed, 156 insertions(+), 131 deletions(-) diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index 56574124b272..91bfba6826e0 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -4,6 +4,7 @@ import asyncio import os from collections.abc import Callable +from concurrent.futures import Future from typing import Any import pytest @@ -27,7 +28,7 @@ class CustomMultiprocExecutor(MultiprocExecutor): kwargs: dict | None = None, non_block: bool = False, unique_reply_rank: int | None = None, - ) -> list[Any]: + ) -> Any | list[Any] | Future[Any | list[Any]]: # Drop marker to show that this was run with open(".marker", "w"): ... diff --git a/tests/v1/kv_connector/unit/test_output_aggregator.py b/tests/v1/kv_connector/unit/test_output_aggregator.py index 4dba203ebc7d..d186f677c02f 100644 --- a/tests/v1/kv_connector/unit/test_output_aggregator.py +++ b/tests/v1/kv_connector/unit/test_output_aggregator.py @@ -89,14 +89,12 @@ def test_aggregate_workers_output(): def test_async_aggregate_workers_output(): aggregator = KVOutputAggregator(expected_finished_count=2) - future1: Future[DummyModelRunnerOutput] = Future() - future2: Future[DummyModelRunnerOutput] = Future() - result_future = aggregator.async_aggregate([future1, future2]) + future: Future[list[DummyModelRunnerOutput]] = Future() + result_future = aggregator.async_aggregate(future) output1 = DummyModelRunnerOutput() output2 = DummyModelRunnerOutput() - future1.set_result(output1) - future2.set_result(output2) + future.set_result([output1, output2]) assert result_future.done() aggregated = result_future.result() @@ -106,16 +104,14 @@ def test_async_aggregate_workers_output(): assert aggregated.finished_recving is None assert not aggregated.invalid_block_ids - future1 = Future() - future2 = Future() - result_future = aggregator.async_aggregate([future1, future2]) + future = Future() + result_future = aggregator.async_aggregate(future) output1 = DummyModelRunnerOutput( finished_sending={"req1"}, finished_recving={"req2"} ) output2 = DummyModelRunnerOutput(invalid_block_ids={1}) - future1.set_result(output1) - future2.set_result(output2) + future.set_result([output1, output2]) assert result_future.done() aggregated = result_future.result() @@ -125,14 +121,12 @@ def test_async_aggregate_workers_output(): assert aggregated.finished_recving is None assert aggregated.invalid_block_ids == {1} - future1 = Future() - future2 = Future() - result_future = aggregator.async_aggregate([future1, future2]) + future = Future() + result_future = aggregator.async_aggregate(future) output1 = DummyModelRunnerOutput(invalid_block_ids={2}) output2 = DummyModelRunnerOutput(finished_sending={"req1"}) - future1.set_result(output1) - future2.set_result(output2) + future.set_result([output1, output2]) assert result_future.done() aggregated = 
result_future.result() @@ -142,16 +136,14 @@ def test_async_aggregate_workers_output(): assert aggregated.finished_recving is None assert aggregated.invalid_block_ids == {2} - future1 = Future() - future2 = Future() - result_future = aggregator.async_aggregate([future1, future2]) + future = Future() + result_future = aggregator.async_aggregate(future) output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4}) output2 = DummyModelRunnerOutput( finished_recving={"req2"}, invalid_block_ids={4, 5} ) - future1.set_result(output1) - future2.set_result(output2) + future.set_result([output1, output2]) assert result_future.done() aggregated = result_future.result() diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 7464f8469c3b..33a801e135d4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -4,6 +4,7 @@ KV cache helper for store. """ +import contextlib from collections.abc import Sequence from concurrent.futures import CancelledError, Future from typing import TYPE_CHECKING, Literal @@ -221,38 +222,38 @@ class KVOutputAggregator: def async_aggregate( self, - output_futures: Sequence[Future[ModelRunnerOutput | None]], + output_future: Future[Sequence[ModelRunnerOutput | None]], output_rank: int = 0, ) -> Future[ModelRunnerOutput | None]: - """Takes a list of futures and returns a single future which resolves - to the respective list of outputs.""" + """Takes a future that resolves to a list of outputs and returns a future + which resolves to a single aggregated output.""" result_future: Future[ModelRunnerOutput | None] = Future() - outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures) - remaining = len(output_futures) + def callback(fut): + if result_future.done(): + return + try: + result_future.set_result(self.aggregate(fut.result(), output_rank)) + except CancelledError: + result_future.cancel() + except Exception as e: + result_future.set_exception(e) - def make_callback(idx): - def callback(fut): - if result_future.done(): - return + output_future.add_done_callback(callback) - try: - outputs[idx] = fut.result() - except CancelledError: - result_future.cancel() - except Exception as e: - result_future.set_exception(e) + from vllm.v1.executor.multiproc_executor import FutureWrapper - # this check assumes io_thread_pool uses a single thread - nonlocal remaining - remaining -= 1 - if not remaining: - result_future.set_result(self.aggregate(outputs, output_rank)) + if isinstance(output_future, FutureWrapper): + # Due to the threadless implementation of multiproc FutureWrapper, + # we must block on the delegate future's result() method. 
+ delegate_result = result_future.result - return callback + def result(timeout=None): + with contextlib.suppress(Exception): + output_future.result(timeout=timeout) + return delegate_result() - for i, output_future in enumerate(output_futures): - output_future.add_done_callback(make_callback(i)) + result_future.result = result # type: ignore[method-assign] return result_future diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index d76c6107ad2b..1e913876b763 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -171,7 +171,7 @@ class Executor(ABC): args: tuple = (), kwargs: dict | None = None, non_block: Literal[True] = True, - ) -> list[Future[_R]]: + ) -> Future[list[_R]]: pass @abstractmethod @@ -219,7 +219,7 @@ class Executor(ABC): def sample_tokens( self, grammar_output: GrammarOutput | None, non_block: bool = False - ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] "sample_tokens", args=(grammar_output,), non_block=non_block ) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 999a3ba870ea..c9a50ecaa1de 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -9,8 +9,10 @@ import threading import time import traceback import weakref +from collections import deque from collections.abc import Callable -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future, InvalidStateError +from contextlib import suppress from dataclasses import dataclass from enum import Enum, auto from functools import cached_property, partial @@ -54,6 +56,30 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) +class FutureWrapper(Future): + def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]): + self.futures_queue = futures_queue + super().__init__() + + def result(self, timeout=None): + if timeout is not None: + raise RuntimeError("timeout not implemented") + # Drain any futures ahead of us in the queue. + while not self.done(): + future, get_response = self.futures_queue.pop() + future.wait_for_response(get_response) + return super().result() + + def wait_for_response(self, get_response: Callable): + try: + response = get_response() + with suppress(InvalidStateError): + self.set_result(response) + except Exception as e: + with suppress(InvalidStateError): + self.set_exception(e) + + class MultiprocExecutor(Executor): supports_pp: bool = True @@ -64,7 +90,6 @@ class MultiprocExecutor(Executor): self.is_failed = False self.shutdown_event = threading.Event() self.failure_callback: FailureCallback | None = None - self.io_thread_pool: ThreadPoolExecutor | None = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -132,12 +157,7 @@ class MultiprocExecutor(Executor): uw.death_writer.close() self._ensure_worker_termination([uw.proc for uw in unready_workers]) - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue. - # _async_aggregate_workers_output also assumes a single IO thread. 
- self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io" - ) + self.futures_queue = deque[tuple[FutureWrapper, Callable]]() self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None @@ -195,14 +215,13 @@ class MultiprocExecutor(Executor): ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: if not self.has_connector: # get output only from a single worker (output_rank) - (output,) = self.collective_rpc( + return self.collective_rpc( method, args=args, unique_reply_rank=self.output_rank, non_block=non_block, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, ) - return output # get output from all workers outputs = self.collective_rpc( @@ -223,12 +242,11 @@ class MultiprocExecutor(Executor): def take_draft_token_ids(self) -> DraftTokenIds | None: # OPTIMIZATION: Get output only from a single worker (output_rank) - outputs = self.collective_rpc( + return self.collective_rpc( "take_draft_token_ids", unique_reply_rank=self.output_rank ) - return outputs[0] - def collective_rpc( + def collective_rpc( # type: ignore[override] self, method: str | Callable, timeout: float | None = None, @@ -236,7 +254,9 @@ class MultiprocExecutor(Executor): kwargs: dict | None = None, non_block: bool = False, unique_reply_rank: int | None = None, - ) -> list[Any]: + ) -> Any | list[Any] | Future[Any | list[Any]]: + """Returns single result if unique_reply_rank is provided, otherwise list.""" + if self.is_failed: raise RuntimeError("Executor failed.") @@ -246,63 +266,52 @@ class MultiprocExecutor(Executor): # NOTE: If the args are heterogeneous, then we pack them into a list, # and unpack them in the method of every worker, because every worker # knows their own rank. - try: - if isinstance(method, str): - send_method = method - else: - send_method = cloudpickle.dumps( - method, protocol=pickle.HIGHEST_PROTOCOL - ) - self.rpc_broadcast_mq.enqueue( - (send_method, args, kwargs, unique_reply_rank) - ) - workers = ( - (self.workers[unique_reply_rank],) - if unique_reply_rank is not None - else self.workers - ) + if isinstance(method, str): + send_method = method + else: + send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL) + self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank)) + + workers = ( + (self.workers[unique_reply_rank],) + if unique_reply_rank is not None + else self.workers + ) + + shutdown_event = self.shutdown_event + + def get_response(): responses = [] - - def get_response( - w: WorkerProcHandle, - dequeue_timeout: float | None = None, - cancel_event: threading.Event | None = None, - ): - status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=cancel_event + for w in workers: + dequeue_timeout = ( + None if deadline is None else (deadline - time.monotonic()) ) - + try: + status, result = w.worker_response_mq.dequeue( + timeout=dequeue_timeout, cancel=shutdown_event + ) + except TimeoutError as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( f"Worker failed with error '{result}', please check the" " stack trace above for the root cause" ) - return result - - for w in workers: - dequeue_timeout = ( - None if deadline is None else (deadline - time.monotonic()) - ) - - if self.io_thread_pool is not None: - # We must consume worker_response_mq from a single thread. 
- result = self.io_thread_pool.submit( # type: ignore - get_response, w, dequeue_timeout, self.shutdown_event - ) - if not non_block: - result = result.result() - elif not non_block: - result = get_response(w, dequeue_timeout, self.shutdown_event) - else: - raise RuntimeError( - "non_block can only be used when max_concurrent_batches > 1" - ) responses.append(result) + return responses[0] if unique_reply_rank is not None else responses - return responses - except TimeoutError as e: - raise TimeoutError(f"RPC call to {method} timed out.") from e + if non_block: + future = FutureWrapper(self.futures_queue) + self.futures_queue.appendleft((future, get_response)) + return future + + # First drain any pending futures in the queue. + while self.futures_queue: + future, get_fut_response = self.futures_queue.pop() + future.wait_for_response(get_fut_response) + + return get_response() @staticmethod def _ensure_worker_termination(worker_procs: list[BaseProcess]): @@ -348,9 +357,6 @@ class MultiprocExecutor(Executor): self._ensure_worker_termination([w.proc for w in workers]) self.shutdown_event.set() - if self.io_thread_pool is not None: - self.io_thread_pool.shutdown(wait=False, cancel_futures=True) - del self.io_thread_pool self.rpc_broadcast_mq = None diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 4a69cca723ac..119e4c081831 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -435,26 +435,25 @@ class RayDistributedExecutor(Executor): # When PP is used, we return a FutureWrapper immediately so that # the scheduler can yield to the next batch. - return FutureWrapper(refs) + return FutureWrapper(refs[0]) # Get output from all workers when connector is present assert self.kv_output_aggregator is not None if not non_block: # Block and get results from all workers - outputs = [ref.get() for ref in refs] - return self.kv_output_aggregator.aggregate(outputs) + return self.kv_output_aggregator.aggregate(ray.get(refs)) # Return a future that will aggregate outputs from all workers return FutureWrapper(refs, self.kv_output_aggregator) - def collective_rpc( + def collective_rpc( # type: ignore[override] self, method: str | Callable, timeout: float | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, non_block: bool = False, - ) -> list[Any]: + ) -> list[Any] | Future[list[Any]]: """Runs the given method on all workers.""" sent_method = method if isinstance(method, str) else cloudpickle.dumps(method) del method @@ -470,7 +469,7 @@ class RayDistributedExecutor(Executor): # Get the results of the ray workers. if non_block: - return [FutureWrapper((output,)) for output in ray_worker_outputs] + return FutureWrapper(ray_worker_outputs) return ray.get(ray_worker_outputs, timeout=timeout) diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index a282cdc9909d..21910d1160bd 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -141,19 +141,16 @@ class FutureWrapper(Future): the result() call. If not only the first worker's output is returned. 
""" - def __init__(self, refs, aggregator: KVOutputAggregator | None = None): + def __init__(self, ref_or_refs, aggregator: KVOutputAggregator | None = None): super().__init__() - self.refs = refs + self.ref_or_refs = ref_or_refs self.aggregator = aggregator def result(self, timeout=None): - if timeout is not None: - raise NotImplementedError("timeout is not supported") - + outputs = ray.get(self.ref_or_refs, timeout=timeout) if self.aggregator is None: - return self.refs[0].get() + return outputs - outputs = [ref.get() for ref in self.refs] return self.aggregator.aggregate(outputs, output_rank=0) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 32f00949b7f7..095d3d1dac21 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -13,9 +13,10 @@ import torch.distributed as dist import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor -from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.serial_utils import run_method from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -58,32 +59,60 @@ class UniProcExecutor(Executor): def max_concurrent_batches(self) -> int: return 2 if self.scheduler_config.async_scheduling else 1 - def collective_rpc( + def collective_rpc( # type: ignore[override] self, method: str | Callable, timeout: float | None = None, args: tuple = (), kwargs: dict | None = None, non_block: bool = False, - ) -> list[Any]: + single_value: bool = False, + ) -> Any | list[Any] | Future[Any | list[Any]]: if kwargs is None: kwargs = {} if not non_block: - return [run_method(self.driver_worker, method, args, kwargs)] + result = run_method(self.driver_worker, method, args, kwargs) + return result if single_value else [result] try: result = run_method(self.driver_worker, method, args, kwargs) if isinstance(result, AsyncModelRunnerOutput): if (async_thread := self.async_output_thread) is not None: - return [async_thread.submit(result.get_output)] + get_output = result.get_output + if not single_value: + get_output = lambda go=result.get_output: [go()] + return async_thread.submit(get_output) result = result.get_output() future = Future[Any]() - future.set_result(result) + future.set_result(result if single_value else [result]) except Exception as e: future = Future[Any]() future.set_exception(e) - return [future] + return future + + def execute_model( # type: ignore[override] + self, scheduler_output: SchedulerOutput, non_block: bool = False + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: + return self.collective_rpc( + "execute_model", + args=(scheduler_output,), + non_block=non_block, + single_value=True, + ) + + def sample_tokens( # type: ignore[override] + self, grammar_output: GrammarOutput | None, non_block: bool = False + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: + return self.collective_rpc( + "sample_tokens", + args=(grammar_output,), + non_block=non_block, + single_value=True, + ) + + def take_draft_token_ids(self) -> DraftTokenIds | None: + return self.collective_rpc("take_draft_token_ids", single_value=True) def check_health(self) -> None: # UniProcExecutor will always be 
healthy as long as diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 160beb1292e9..f13ff4e726bd 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -524,7 +524,7 @@ class Worker(WorkerBase): @torch.inference_mode() def sample_tokens( - self, grammar_output: "GrammarOutput" + self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput: return self.model_runner.sample_tokens(grammar_output)
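
For reference, a minimal standalone sketch (not part of the patch) of the threadless-future pattern that MultiprocExecutor adopts above: instead of a dedicated IO thread completing futures as worker responses arrive, pending (future, getter) pairs sit in a shared deque and are resolved lazily when result() is called, which also drains any futures queued ahead of the one being waited on. The names below (LazyFuture, fulfill, and the lambda getter) are illustrative only and do not exist in vLLM; the in-tree equivalents are FutureWrapper and wait_for_response in vllm/v1/executor/multiproc_executor.py.

from collections import deque
from collections.abc import Callable
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress


class LazyFuture(Future):
    """Completes itself, and anything queued ahead of it, only when
    result() is called, so no background IO thread is required."""

    def __init__(self, queue: deque[tuple["LazyFuture", Callable]]):
        super().__init__()
        self.queue = queue

    def result(self, timeout=None):
        # Drain earlier pending futures first so responses are consumed
        # in the order their RPCs were issued.
        while not self.done():
            future, get_response = self.queue.pop()
            future.fulfill(get_response)
        return super().result()

    def fulfill(self, get_response: Callable):
        try:
            response = get_response()  # the blocking dequeue happens here
            with suppress(InvalidStateError):
                self.set_result(response)
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)


# Usage: producers append on the left, result() pops from the right (FIFO).
queue: deque[tuple[LazyFuture, Callable]] = deque()
fut = LazyFuture(queue)
queue.appendleft((fut, lambda: "response"))
assert fut.result() == "response"

The trade-off mirrors the patch: the cost of completing a future is paid by the caller of result() rather than by a separate thread spinning on the shared-memory response queue, consistent with the subject line's "shm spin".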