[PerfFix] Avoid separate thread for MP executor shm spin (take 2) (#28319)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-11-07 14:11:03 -08:00 committed by GitHub
parent da786e339e
commit 67a2da890e
9 changed files with 156 additions and 131 deletions

View File

@@ -4,6 +4,7 @@
import asyncio
import os
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any
import pytest
@@ -27,7 +28,7 @@ class CustomMultiprocExecutor(MultiprocExecutor):
kwargs: dict | None = None,
non_block: bool = False,
unique_reply_rank: int | None = None,
) -> list[Any]:
) -> Any | list[Any] | Future[Any | list[Any]]:
# Drop marker to show that this was run
with open(".marker", "w"):
...

View File

@@ -89,14 +89,12 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output():
aggregator = KVOutputAggregator(expected_finished_count=2)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
result_future = aggregator.async_aggregate([future1, future2])
future: Future[list[DummyModelRunnerOutput]] = Future()
result_future = aggregator.async_aggregate(future)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
future1.set_result(output1)
future2.set_result(output2)
future.set_result([output1, output2])
assert result_future.done()
aggregated = result_future.result()
@@ -106,16 +104,14 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
future = Future()
result_future = aggregator.async_aggregate(future)
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
future1.set_result(output1)
future2.set_result(output2)
future.set_result([output1, output2])
assert result_future.done()
aggregated = result_future.result()
@@ -125,14 +121,12 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {1}
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
future = Future()
result_future = aggregator.async_aggregate(future)
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
future1.set_result(output1)
future2.set_result(output2)
future.set_result([output1, output2])
assert result_future.done()
aggregated = result_future.result()
@@ -142,16 +136,14 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
future = Future()
result_future = aggregator.async_aggregate(future)
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
future1.set_result(output1)
future2.set_result(output2)
future.set_result([output1, output2])
assert result_future.done()
aggregated = result_future.result()

View File

@@ -4,6 +4,7 @@
KV cache helper for store.
"""
import contextlib
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal
@@ -221,38 +222,38 @@ class KVOutputAggregator:
def async_aggregate(
self,
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_future: Future[Sequence[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
"""Takes a future that resolves to a list of outputs and returns a future
which resolves to a single aggregated output."""
result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def callback(fut):
if result_future.done():
return
try:
result_future.set_result(self.aggregate(fut.result(), output_rank))
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
output_future.add_done_callback(callback)
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
from vllm.v1.executor.multiproc_executor import FutureWrapper
# this check assumes io_thread_pool uses a single thread
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))
if isinstance(output_future, FutureWrapper):
# Due to the threadless implementation of multiproc FutureWrapper,
# we must block on the delegate future's result() method.
delegate_result = result_future.result
return callback
def result(timeout=None):
with contextlib.suppress(Exception):
output_future.result(timeout=timeout)
return delegate_result()
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
result_future.result = result # type: ignore[method-assign]
return result_future
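
The reworked async_aggregate above chains one future onto another instead of fanning in per-worker futures: a done-callback on the incoming list-future aggregates the outputs and fulfils the outgoing future. A minimal sketch of that chaining pattern, with a stand-in aggregate() in place of KVOutputAggregator.aggregate (all names in the sketch are illustrative, not part of this diff):

from concurrent.futures import CancelledError, Future


def aggregate(outputs):
    # Stand-in for KVOutputAggregator.aggregate: combine per-worker outputs.
    return sum(outputs)


def chain_aggregate(output_future: Future) -> Future:
    result_future: Future = Future()

    def callback(fut):
        if result_future.done():
            return
        try:
            result_future.set_result(aggregate(fut.result()))
        except CancelledError:
            result_future.cancel()
        except Exception as e:
            result_future.set_exception(e)

    output_future.add_done_callback(callback)
    return result_future


workers_future: Future = Future()
agg_future = chain_aggregate(workers_future)
workers_future.set_result([1, 2, 3])  # all per-worker outputs arrive together
assert agg_future.result() == 6

In the actual change, when the incoming future is the multiproc FutureWrapper, result() on the returned future is additionally overridden to block on the wrapped future first, because that wrapper only completes once something calls its result().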

View File

@@ -171,7 +171,7 @@ class Executor(ABC):
args: tuple = (),
kwargs: dict | None = None,
non_block: Literal[True] = True,
) -> list[Future[_R]]:
) -> Future[list[_R]]:
pass
@abstractmethod
@@ -219,7 +219,7 @@
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)

View File

@@ -9,8 +9,10 @@ import threading
import time
import traceback
import weakref
from collections import deque
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property, partial
@@ -54,6 +56,30 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class FutureWrapper(Future):
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
self.futures_queue = futures_queue
super().__init__()
def result(self, timeout=None):
if timeout is not None:
raise RuntimeError("timeout not implemented")
# Drain any futures ahead of us in the queue.
while not self.done():
future, get_response = self.futures_queue.pop()
future.wait_for_response(get_response)
return super().result()
def wait_for_response(self, get_response: Callable):
try:
response = get_response()
with suppress(InvalidStateError):
self.set_result(response)
except Exception as e:
with suppress(InvalidStateError):
self.set_exception(e)
class MultiprocExecutor(Executor):
supports_pp: bool = True
@@ -64,7 +90,6 @@ class MultiprocExecutor(Executor):
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None
self.io_thread_pool: ThreadPoolExecutor | None = None
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +157,7 @@
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +215,13 @@
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output,) = self.collective_rpc(
return self.collective_rpc(
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
return output
# get output from all workers
outputs = self.collective_rpc(
@@ -223,12 +242,11 @@
def take_draft_token_ids(self) -> DraftTokenIds | None:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc(
return self.collective_rpc(
"take_draft_token_ids", unique_reply_rank=self.output_rank
)
return outputs[0]
def collective_rpc(
def collective_rpc( # type: ignore[override]
self,
method: str | Callable,
timeout: float | None = None,
@@ -236,7 +254,9 @@
kwargs: dict | None = None,
non_block: bool = False,
unique_reply_rank: int | None = None,
) -> list[Any]:
) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank is provided, otherwise list."""
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -246,63 +266,52 @@
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try:
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL
)
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, unique_reply_rank)
)
workers = (
(self.workers[unique_reply_rank],)
if unique_reply_rank is not None
else self.workers
)
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
workers = (
(self.workers[unique_reply_rank],)
if unique_reply_rank is not None
else self.workers
)
shutdown_event = self.shutdown_event
def get_response():
responses = []
def get_response(
w: WorkerProcHandle,
dequeue_timeout: float | None = None,
cancel_event: threading.Event | None = None,
):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=cancel_event
for w in workers:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
try:
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event
)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause"
)
return result
for w in workers:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event
)
if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout, self.shutdown_event)
else:
raise RuntimeError(
"non_block can only be used when max_concurrent_batches > 1"
)
responses.append(result)
return responses[0] if unique_reply_rank is not None else responses
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
if non_block:
future = FutureWrapper(self.futures_queue)
self.futures_queue.appendleft((future, get_response))
return future
# First drain any pending futures in the queue.
while self.futures_queue:
future, get_fut_response = self.futures_queue.pop()
future.wait_for_response(get_fut_response)
return get_response()
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
@@ -348,9 +357,6 @@ class MultiprocExecutor(Executor):
self._ensure_worker_termination([w.proc for w in workers])
self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
del self.io_thread_pool
self.rpc_broadcast_mq = None
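
The dedicated IO thread that previously drained the worker response queues is replaced by a deque of pending (future, get_response) pairs drained on whatever thread asks for a result: FutureWrapper.result() completes every entry queued ahead of it, and blocking collective_rpc calls drain the queue before reading their own response. A reduced sketch of that threadless-future idea, with plain callables standing in for the shared-memory response queues (class and variable names here are illustrative):

from collections import deque
from collections.abc import Callable
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress


class QueuedFuture(Future):
    """Future completed by draining a shared FIFO of pending RPCs."""

    def __init__(self, queue: deque[tuple["QueuedFuture", Callable]]):
        self.queue = queue
        super().__init__()

    def result(self, timeout=None):
        # Complete every RPC issued before this one (responses arrive in
        # request order), then our own, all on the calling thread.
        # Timeout handling is omitted in this sketch.
        while not self.done():
            future, get_response = self.queue.pop()
            future.fill(get_response)
        return super().result()

    def fill(self, get_response: Callable):
        try:
            with suppress(InvalidStateError):
                self.set_result(get_response())
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)


queue: deque[tuple[QueuedFuture, Callable]] = deque()

# Two non-blocking "RPCs": responses are read lazily, in issue order.
f1, f2 = QueuedFuture(queue), QueuedFuture(queue)
queue.appendleft((f1, lambda: "first response"))
queue.appendleft((f2, lambda: "second response"))

assert f2.result() == "second response"  # drains f1 first, then f2
assert f1.result() == "first response"   # already completed by the drain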

View File

@@ -435,26 +435,25 @@ class RayDistributedExecutor(Executor):
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)
return FutureWrapper(refs[0])
# Get output from all workers when connector is present
assert self.kv_output_aggregator is not None
if not non_block:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)
return self.kv_output_aggregator.aggregate(ray.get(refs))
# Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator)
def collective_rpc(
def collective_rpc( # type: ignore[override]
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
non_block: bool = False,
) -> list[Any]:
) -> list[Any] | Future[list[Any]]:
"""Runs the given method on all workers."""
sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
del method
@@ -470,7 +469,7 @@
# Get the results of the ray workers.
if non_block:
return [FutureWrapper((output,)) for output in ray_worker_outputs]
return FutureWrapper(ray_worker_outputs)
return ray.get(ray_worker_outputs, timeout=timeout)

View File

@@ -141,19 +141,16 @@ class FutureWrapper(Future):
the result() call. If not, only the first worker's output is returned.
"""
def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
def __init__(self, ref_or_refs, aggregator: KVOutputAggregator | None = None):
super().__init__()
self.refs = refs
self.ref_or_refs = ref_or_refs
self.aggregator = aggregator
def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")
outputs = ray.get(self.ref_or_refs, timeout=timeout)
if self.aggregator is None:
return self.refs[0].get()
return outputs
outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
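
On the Ray path the per-ref futures are likewise collapsed into a single wrapper: the executor now returns FutureWrapper(refs[0]) for the single-rank case or FutureWrapper(refs, aggregator) when a connector is present, and result() resolves everything with one ray.get call. A small sketch of that shape, with a stand-in get_all() in place of ray.get (names are illustrative only):

from concurrent.futures import Future


def get_all(ref_or_refs):
    # Stand-in for ray.get(), which resolves one object ref or a list of refs.
    return ref_or_refs() if callable(ref_or_refs) else [r() for r in ref_or_refs]


class LazyRefFuture(Future):
    def __init__(self, ref_or_refs, aggregate=None):
        super().__init__()
        self.ref_or_refs = ref_or_refs
        self.aggregate = aggregate

    def result(self, timeout=None):
        outputs = get_all(self.ref_or_refs)
        return outputs if self.aggregate is None else self.aggregate(outputs)


single = LazyRefFuture(lambda: "rank0 output")
assert single.result() == "rank0 output"

combined = LazyRefFuture([lambda: 1, lambda: 2], aggregate=sum)
assert combined.result() == 3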

View File

@@ -13,9 +13,10 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase
@@ -58,32 +59,60 @@ class UniProcExecutor(Executor):
def max_concurrent_batches(self) -> int:
return 2 if self.scheduler_config.async_scheduling else 1
def collective_rpc(
def collective_rpc( # type: ignore[override]
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: bool = False,
) -> list[Any]:
single_value: bool = False,
) -> Any | list[Any] | Future[Any | list[Any]]:
if kwargs is None:
kwargs = {}
if not non_block:
return [run_method(self.driver_worker, method, args, kwargs)]
result = run_method(self.driver_worker, method, args, kwargs)
return result if single_value else [result]
try:
result = run_method(self.driver_worker, method, args, kwargs)
if isinstance(result, AsyncModelRunnerOutput):
if (async_thread := self.async_output_thread) is not None:
return [async_thread.submit(result.get_output)]
get_output = result.get_output
if not single_value:
get_output = lambda go=result.get_output: [go()]
return async_thread.submit(get_output)
result = result.get_output()
future = Future[Any]()
future.set_result(result)
future.set_result(result if single_value else [result])
except Exception as e:
future = Future[Any]()
future.set_exception(e)
return [future]
return future
def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self.collective_rpc(
"execute_model",
args=(scheduler_output,),
non_block=non_block,
single_value=True,
)
def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self.collective_rpc(
"sample_tokens",
args=(grammar_output,),
non_block=non_block,
single_value=True,
)
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.collective_rpc("take_draft_token_ids", single_value=True)
def check_health(self) -> None:
# UniProcExecutor will always be healthy as long as
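
UniProcExecutor gains a single_value flag so that execute_model, sample_tokens, and take_draft_token_ids can receive a bare result instead of a one-element list. A tiny sketch of that convention for a one-worker executor (illustrative only, not the vLLM API surface):

from collections.abc import Callable
from typing import Any


def collective_rpc(method: Callable[..., Any], *args: Any,
                   single_value: bool = False) -> Any:
    # One worker means exactly one result: single_value=True returns it bare
    # (what execute_model()/sample_tokens() want), otherwise it is wrapped in
    # a one-element list to keep the usual collective_rpc list contract.
    result = method(*args)
    return result if single_value else [result]


assert collective_rpc(len, "abc") == [3]
assert collective_rpc(len, "abc", single_value=True) == 3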

View File

@@ -524,7 +524,7 @@ class Worker(WorkerBase):
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput"
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)