[PerfFix] Avoid separate thread for MP executor shm spin (take 2) (#28319)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent da786e339e
commit 67a2da890e
@@ -4,6 +4,7 @@
 import asyncio
 import os
 from collections.abc import Callable
+from concurrent.futures import Future
 from typing import Any

 import pytest
@@ -27,7 +28,7 @@ class CustomMultiprocExecutor(MultiprocExecutor):
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
-    ) -> list[Any]:
+    ) -> Any | list[Any] | Future[Any | list[Any]]:
         # Drop marker to show that this was run
         with open(".marker", "w"):
             ...
@@ -89,14 +89,12 @@ def test_aggregate_workers_output():
 def test_async_aggregate_workers_output():
     aggregator = KVOutputAggregator(expected_finished_count=2)

-    future1: Future[DummyModelRunnerOutput] = Future()
-    future2: Future[DummyModelRunnerOutput] = Future()
-    result_future = aggregator.async_aggregate([future1, future2])
+    future: Future[list[DummyModelRunnerOutput]] = Future()
+    result_future = aggregator.async_aggregate(future)

     output1 = DummyModelRunnerOutput()
     output2 = DummyModelRunnerOutput()
-    future1.set_result(output1)
-    future2.set_result(output2)
+    future.set_result([output1, output2])

     assert result_future.done()
     aggregated = result_future.result()
@@ -106,16 +104,14 @@ def test_async_aggregate_workers_output():
     assert aggregated.finished_recving is None
     assert not aggregated.invalid_block_ids

-    future1 = Future()
-    future2 = Future()
-    result_future = aggregator.async_aggregate([future1, future2])
+    future = Future()
+    result_future = aggregator.async_aggregate(future)

     output1 = DummyModelRunnerOutput(
         finished_sending={"req1"}, finished_recving={"req2"}
     )
     output2 = DummyModelRunnerOutput(invalid_block_ids={1})
-    future1.set_result(output1)
-    future2.set_result(output2)
+    future.set_result([output1, output2])

     assert result_future.done()
     aggregated = result_future.result()
@@ -125,14 +121,12 @@ def test_async_aggregate_workers_output():
     assert aggregated.finished_recving is None
     assert aggregated.invalid_block_ids == {1}

-    future1 = Future()
-    future2 = Future()
-    result_future = aggregator.async_aggregate([future1, future2])
+    future = Future()
+    result_future = aggregator.async_aggregate(future)

     output1 = DummyModelRunnerOutput(invalid_block_ids={2})
     output2 = DummyModelRunnerOutput(finished_sending={"req1"})
-    future1.set_result(output1)
-    future2.set_result(output2)
+    future.set_result([output1, output2])

     assert result_future.done()
     aggregated = result_future.result()
@@ -142,16 +136,14 @@ def test_async_aggregate_workers_output():
     assert aggregated.finished_recving is None
     assert aggregated.invalid_block_ids == {2}

-    future1 = Future()
-    future2 = Future()
-    result_future = aggregator.async_aggregate([future1, future2])
+    future = Future()
+    result_future = aggregator.async_aggregate(future)

     output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
     output2 = DummyModelRunnerOutput(
         finished_recving={"req2"}, invalid_block_ids={4, 5}
     )
-    future1.set_result(output1)
-    future2.set_result(output2)
+    future.set_result([output1, output2])

     assert result_future.done()
     aggregated = result_future.result()
@@ -4,6 +4,7 @@
 KV cache helper for store.
 """

+import contextlib
 from collections.abc import Sequence
 from concurrent.futures import CancelledError, Future
 from typing import TYPE_CHECKING, Literal
@@ -221,38 +222,38 @@ class KVOutputAggregator:

     def async_aggregate(
         self,
-        output_futures: Sequence[Future[ModelRunnerOutput | None]],
+        output_future: Future[Sequence[ModelRunnerOutput | None]],
         output_rank: int = 0,
     ) -> Future[ModelRunnerOutput | None]:
-        """Takes a list of futures and returns a single future which resolves
-        to the respective list of outputs."""
+        """Takes a future that resolves to a list of outputs and returns a future
+        which resolves to a single aggregated output."""
         result_future: Future[ModelRunnerOutput | None] = Future()

-        outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
-        remaining = len(output_futures)
-
-        def make_callback(idx):
-            def callback(fut):
-                if result_future.done():
-                    return
-
-                try:
-                    outputs[idx] = fut.result()
-                except CancelledError:
-                    result_future.cancel()
-                except Exception as e:
-                    result_future.set_exception(e)
-
-                # this check assumes io_thread_pool uses a single thread
-                nonlocal remaining
-                remaining -= 1
-                if not remaining:
-                    result_future.set_result(self.aggregate(outputs, output_rank))
-
-            return callback
-
-        for i, output_future in enumerate(output_futures):
-            output_future.add_done_callback(make_callback(i))
+        def callback(fut):
+            if result_future.done():
+                return
+            try:
+                result_future.set_result(self.aggregate(fut.result(), output_rank))
+            except CancelledError:
+                result_future.cancel()
+            except Exception as e:
+                result_future.set_exception(e)
+
+        output_future.add_done_callback(callback)
+
+        from vllm.v1.executor.multiproc_executor import FutureWrapper
+
+        if isinstance(output_future, FutureWrapper):
+            # Due to the threadless implementation of multiproc FutureWrapper,
+            # we must block on the delegate future's result() method.
+            delegate_result = result_future.result
+
+            def result(timeout=None):
+                with contextlib.suppress(Exception):
+                    output_future.result(timeout=timeout)
+                return delegate_result()
+
+            result_future.result = result  # type: ignore[method-assign]

         return result_future
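As an aside, here is a minimal, self-contained sketch of the delegate-result pattern used above (illustrative only, not vLLM's code): the aggregated future normally completes via add_done_callback, but when the upstream future only makes progress inside its own result() call, the downstream result() has to drive the upstream future first.

from concurrent.futures import CancelledError, Future


def derive(upstream: Future, transform) -> Future:
    """Return a future that resolves to transform(upstream.result())."""
    downstream: Future = Future()

    def callback(fut: Future) -> None:
        if downstream.done():
            return
        try:
            downstream.set_result(transform(fut.result()))
        except CancelledError:
            downstream.cancel()
        except Exception as e:
            downstream.set_exception(e)

    upstream.add_done_callback(callback)

    # If the upstream future only completes inside its own result() call
    # (a "lazy" future), wrap downstream.result() so it pumps upstream first.
    original_result = downstream.result

    def result(timeout=None):
        try:
            upstream.result(timeout=timeout)  # triggers the callback above
        except Exception:
            pass  # the callback has already recorded the failure
        return original_result()

    downstream.result = result  # type: ignore[method-assign]
    return downstream

With an ordinary Future the callback fires as soon as set_result() is called; with a lazy future, calling the derived future's result() pulls the value through the chain.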
@@ -171,7 +171,7 @@ class Executor(ABC):
         args: tuple = (),
         kwargs: dict | None = None,
         non_block: Literal[True] = True,
-    ) -> list[Future[_R]]:
+    ) -> Future[list[_R]]:
         pass

     @abstractmethod
@@ -219,7 +219,7 @@ class Executor(ABC):

     def sample_tokens(
         self, grammar_output: GrammarOutput | None, non_block: bool = False
-    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
+    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
         output = self.collective_rpc(  # type: ignore[call-overload]
             "sample_tokens", args=(grammar_output,), non_block=non_block
         )
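The overload above now advertises a single future that resolves to the whole per-worker list, rather than one future per worker. A generic illustration of how the two shapes relate (the helper and names are hypothetical, not part of vLLM):

from concurrent.futures import Future, ThreadPoolExecutor


def as_future_of_list(pool: ThreadPoolExecutor, futures: list[Future]) -> Future:
    # Collapse "list of futures" into "future of a list" via a helper task.
    return pool.submit(lambda: [f.result() for f in futures])


with ThreadPoolExecutor(max_workers=4) as pool:
    per_task: list[Future[int]] = [pool.submit(pow, n, 2) for n in range(4)]
    combined: Future[list[int]] = as_future_of_list(pool, per_task)
    print(combined.result())  # [0, 1, 4, 9]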
@@ -9,8 +9,10 @@ import threading
 import time
 import traceback
 import weakref
+from collections import deque
 from collections.abc import Callable
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import Future, InvalidStateError
+from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import cached_property, partial
@@ -54,6 +56,30 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase
 logger = init_logger(__name__)


+class FutureWrapper(Future):
+    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
+        self.futures_queue = futures_queue
+        super().__init__()
+
+    def result(self, timeout=None):
+        if timeout is not None:
+            raise RuntimeError("timeout not implemented")
+        # Drain any futures ahead of us in the queue.
+        while not self.done():
+            future, get_response = self.futures_queue.pop()
+            future.wait_for_response(get_response)
+        return super().result()
+
+    def wait_for_response(self, get_response: Callable):
+        try:
+            response = get_response()
+            with suppress(InvalidStateError):
+                self.set_result(response)
+        except Exception as e:
+            with suppress(InvalidStateError):
+                self.set_exception(e)
+
+
 class MultiprocExecutor(Executor):
     supports_pp: bool = True

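A self-contained sketch of the idea behind this class (names and queue contents are simplified stand-ins, not vLLM's exact code): each non-blocking RPC enqueues a (future, reader) pair, and whichever future is asked for its result first drains everything queued ahead of it, so responses are consumed in submission order on the caller's thread instead of on a dedicated IO thread.

from collections import deque
from collections.abc import Callable
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress


class LazyFuture(Future):
    """Future that completes by draining a shared queue of pending readers."""

    def __init__(self, queue: deque):
        super().__init__()
        self.queue = queue

    def result(self, timeout=None):
        # Drain any requests queued ahead of us, in submission order.
        while not self.done():
            future, read_response = self.queue.pop()
            future.fill(read_response)
        return super().result()

    def fill(self, read_response: Callable):
        try:
            response = read_response()
            with suppress(InvalidStateError):
                self.set_result(response)
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)


# Usage: two "RPCs" whose responses are produced lazily, in order.
queue: deque = deque()
responses = iter(["first", "second"])

f1, f2 = LazyFuture(queue), LazyFuture(queue)
queue.appendleft((f1, lambda: next(responses)))
queue.appendleft((f2, lambda: next(responses)))

print(f2.result())  # "second" -- resolving f2 drains f1 first
print(f1.result())  # "first"  -- already filled, returns immediately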
@@ -64,7 +90,6 @@ class MultiprocExecutor(Executor):
         self.is_failed = False
         self.shutdown_event = threading.Event()
         self.failure_callback: FailureCallback | None = None
-        self.io_thread_pool: ThreadPoolExecutor | None = None

         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +157,7 @@
             uw.death_writer.close()
         self._ensure_worker_termination([uw.proc for uw in unready_workers])

-        # Note: must use only 1 IO thread to keep dequeue sequence
-        # from the response queue.
-        # _async_aggregate_workers_output also assumes a single IO thread.
-        self.io_thread_pool = ThreadPoolExecutor(
-            max_workers=1, thread_name_prefix="mp_exec_io"
-        )
+        self.futures_queue = deque[tuple[FutureWrapper, Callable]]()

         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +215,13 @@
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
         if not self.has_connector:
             # get output only from a single worker (output_rank)
-            (output,) = self.collective_rpc(
+            return self.collective_rpc(
                 method,
                 args=args,
                 unique_reply_rank=self.output_rank,
                 non_block=non_block,
                 timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
             )
-            return output

         # get output from all workers
         outputs = self.collective_rpc(
@@ -223,12 +242,11 @@

     def take_draft_token_ids(self) -> DraftTokenIds | None:
         # OPTIMIZATION: Get output only from a single worker (output_rank)
-        outputs = self.collective_rpc(
+        return self.collective_rpc(
             "take_draft_token_ids", unique_reply_rank=self.output_rank
         )
-        return outputs[0]

-    def collective_rpc(
+    def collective_rpc(  # type: ignore[override]
         self,
         method: str | Callable,
         timeout: float | None = None,
@@ -236,7 +254,9 @@
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
-    ) -> list[Any]:
+    ) -> Any | list[Any] | Future[Any | list[Any]]:
+        """Returns single result if unique_reply_rank is provided, otherwise list."""
+
         if self.is_failed:
             raise RuntimeError("Executor failed.")

@@ -246,63 +266,52 @@
         # NOTE: If the args are heterogeneous, then we pack them into a list,
         # and unpack them in the method of every worker, because every worker
         # knows their own rank.
-        try:
-            if isinstance(method, str):
-                send_method = method
-            else:
-                send_method = cloudpickle.dumps(
-                    method, protocol=pickle.HIGHEST_PROTOCOL
-                )
-            self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, unique_reply_rank)
-            )
-
-            workers = (
-                (self.workers[unique_reply_rank],)
-                if unique_reply_rank is not None
-                else self.workers
-            )
-            responses = []
-
-            def get_response(
-                w: WorkerProcHandle,
-                dequeue_timeout: float | None = None,
-                cancel_event: threading.Event | None = None,
-            ):
-                status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=cancel_event
-                )
-                if status != WorkerProc.ResponseStatus.SUCCESS:
-                    raise RuntimeError(
-                        f"Worker failed with error '{result}', please check the"
-                        " stack trace above for the root cause"
-                    )
-                return result
-
-            for w in workers:
-                dequeue_timeout = (
-                    None if deadline is None else (deadline - time.monotonic())
-                )
-
-                if self.io_thread_pool is not None:
-                    # We must consume worker_response_mq from a single thread.
-                    result = self.io_thread_pool.submit(  # type: ignore
-                        get_response, w, dequeue_timeout, self.shutdown_event
-                    )
-                    if not non_block:
-                        result = result.result()
-                elif not non_block:
-                    result = get_response(w, dequeue_timeout, self.shutdown_event)
-                else:
-                    raise RuntimeError(
-                        "non_block can only be used when max_concurrent_batches > 1"
-                    )
-                responses.append(result)
-
-            return responses
-        except TimeoutError as e:
-            raise TimeoutError(f"RPC call to {method} timed out.") from e
+        if isinstance(method, str):
+            send_method = method
+        else:
+            send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
+        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
+
+        workers = (
+            (self.workers[unique_reply_rank],)
+            if unique_reply_rank is not None
+            else self.workers
+        )
+
+        shutdown_event = self.shutdown_event
+
+        def get_response():
+            responses = []
+            for w in workers:
+                dequeue_timeout = (
+                    None if deadline is None else (deadline - time.monotonic())
+                )
+                try:
+                    status, result = w.worker_response_mq.dequeue(
+                        timeout=dequeue_timeout, cancel=shutdown_event
+                    )
+                except TimeoutError as e:
+                    raise TimeoutError(f"RPC call to {method} timed out.") from e
+                if status != WorkerProc.ResponseStatus.SUCCESS:
+                    raise RuntimeError(
+                        f"Worker failed with error '{result}', please check the"
+                        " stack trace above for the root cause"
+                    )
+                responses.append(result)
+            return responses[0] if unique_reply_rank is not None else responses
+
+        if non_block:
+            future = FutureWrapper(self.futures_queue)
+            self.futures_queue.appendleft((future, get_response))
+            return future
+
+        # First drain any pending futures in the queue.
+        while self.futures_queue:
+            future, get_fut_response = self.futures_queue.pop()
+            future.wait_for_response(get_fut_response)
+
+        return get_response()

     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
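The blocking path above preserves the same ordering guarantee: futures still pending from earlier non-blocking calls are resolved before the new request's responses are read from the shared-memory queue. A minimal sketch of that drain-then-read idiom, reusing the LazyFuture sketch shown earlier (hypothetical names):

def blocking_rpc(queue, get_response):
    # Resolve everything submitted non-blockingly before this call...
    while queue:
        pending_future, pending_reader = queue.pop()
        pending_future.fill(pending_reader)
    # ...then read this call's own response directly on the calling thread.
    return get_response()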
@@ -348,9 +357,6 @@
             self._ensure_worker_termination([w.proc for w in workers])

         self.shutdown_event.set()
-        if self.io_thread_pool is not None:
-            self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
-            del self.io_thread_pool

         self.rpc_broadcast_mq = None

@@ -435,26 +435,25 @@ class RayDistributedExecutor(Executor):

             # When PP is used, we return a FutureWrapper immediately so that
             # the scheduler can yield to the next batch.
-            return FutureWrapper(refs)
+            return FutureWrapper(refs[0])

         # Get output from all workers when connector is present
         assert self.kv_output_aggregator is not None
         if not non_block:
             # Block and get results from all workers
-            outputs = [ref.get() for ref in refs]
-            return self.kv_output_aggregator.aggregate(outputs)
+            return self.kv_output_aggregator.aggregate(ray.get(refs))

         # Return a future that will aggregate outputs from all workers
         return FutureWrapper(refs, self.kv_output_aggregator)

-    def collective_rpc(
+    def collective_rpc(  # type: ignore[override]
         self,
         method: str | Callable,
         timeout: float | None = None,
         args: tuple = (),
         kwargs: dict[str, Any] | None = None,
         non_block: bool = False,
-    ) -> list[Any]:
+    ) -> list[Any] | Future[list[Any]]:
         """Runs the given method on all workers."""
         sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
         del method
@@ -470,7 +469,7 @@ class RayDistributedExecutor(Executor):

         # Get the results of the ray workers.
         if non_block:
-            return [FutureWrapper((output,)) for output in ray_worker_outputs]
+            return FutureWrapper(ray_worker_outputs)

         return ray.get(ray_worker_outputs, timeout=timeout)

@@ -141,19 +141,16 @@ class FutureWrapper(Future):
     the result() call. If not only the first worker's output is returned.
     """

-    def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
+    def __init__(self, ref_or_refs, aggregator: KVOutputAggregator | None = None):
         super().__init__()
-        self.refs = refs
+        self.ref_or_refs = ref_or_refs
         self.aggregator = aggregator

     def result(self, timeout=None):
-        if timeout is not None:
-            raise NotImplementedError("timeout is not supported")
+        outputs = ray.get(self.ref_or_refs, timeout=timeout)

         if self.aggregator is None:
-            return self.refs[0].get()
+            return outputs

-        outputs = [ref.get() for ref in self.refs]
         return self.aggregator.aggregate(outputs, output_rank=0)


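The Ray-side wrapper now simply defers one ray.get over the stored ref (or list of refs) until result() is called, passing the caller's timeout through. A rough standalone sketch of that shape (requires a working Ray install; the aggregation branch is omitted):

from concurrent.futures import Future

import ray


class RayRefFuture(Future):
    def __init__(self, ref_or_refs):
        super().__init__()
        self.ref_or_refs = ref_or_refs

    def result(self, timeout=None):
        # ray.get accepts either one ObjectRef or a list, and honors timeout.
        return ray.get(self.ref_or_refs, timeout=timeout)


@ray.remote
def square(x):
    return x * x


ray.init()
print(RayRefFuture(square.remote(3)).result())                       # 9
print(RayRefFuture([square.remote(i) for i in range(3)]).result())   # [0, 1, 4]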
@@ -13,9 +13,10 @@ import torch.distributed as dist
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.outputs import AsyncModelRunnerOutput
+from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
 from vllm.v1.serial_utils import run_method
 from vllm.v1.worker.worker_base import WorkerWrapperBase

@@ -58,32 +59,60 @@ class UniProcExecutor(Executor):
     def max_concurrent_batches(self) -> int:
         return 2 if self.scheduler_config.async_scheduling else 1

-    def collective_rpc(
+    def collective_rpc(  # type: ignore[override]
         self,
         method: str | Callable,
         timeout: float | None = None,
         args: tuple = (),
         kwargs: dict | None = None,
         non_block: bool = False,
-    ) -> list[Any]:
+        single_value: bool = False,
+    ) -> Any | list[Any] | Future[Any | list[Any]]:
         if kwargs is None:
             kwargs = {}

         if not non_block:
-            return [run_method(self.driver_worker, method, args, kwargs)]
+            result = run_method(self.driver_worker, method, args, kwargs)
+            return result if single_value else [result]

         try:
             result = run_method(self.driver_worker, method, args, kwargs)
             if isinstance(result, AsyncModelRunnerOutput):
                 if (async_thread := self.async_output_thread) is not None:
-                    return [async_thread.submit(result.get_output)]
+                    get_output = result.get_output
+                    if not single_value:
+                        get_output = lambda go=result.get_output: [go()]
+                    return async_thread.submit(get_output)
                 result = result.get_output()
             future = Future[Any]()
-            future.set_result(result)
+            future.set_result(result if single_value else [result])
         except Exception as e:
             future = Future[Any]()
             future.set_exception(e)
-        return [future]
+        return future

+    def execute_model(  # type: ignore[override]
+        self, scheduler_output: SchedulerOutput, non_block: bool = False
+    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
+        return self.collective_rpc(
+            "execute_model",
+            args=(scheduler_output,),
+            non_block=non_block,
+            single_value=True,
+        )
+
+    def sample_tokens(  # type: ignore[override]
+        self, grammar_output: GrammarOutput | None, non_block: bool = False
+    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
+        return self.collective_rpc(
+            "sample_tokens",
+            args=(grammar_output,),
+            non_block=non_block,
+            single_value=True,
+        )
+
+    def take_draft_token_ids(self) -> DraftTokenIds | None:
+        return self.collective_rpc("take_draft_token_ids", single_value=True)
+
     def check_health(self) -> None:
         # UniProcExecutor will always be healthy as long as
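The new single_value flag keeps the uniproc return shape consistent with the multiproc executor: a bare value when a single reply was requested, a one-element list otherwise, and a Future around either shape when non_block=True. A toy illustration of that convention (hypothetical helper, not the vLLM API):

from concurrent.futures import Future
from typing import Any


def run_rpc(fn, *, non_block: bool = False, single_value: bool = False) -> Any:
    result = fn()
    value = result if single_value else [result]
    if not non_block:
        return value
    future: Future = Future()
    future.set_result(value)
    return future


print(run_rpc(lambda: 42))                            # [42]
print(run_rpc(lambda: 42, single_value=True))         # 42
print(run_rpc(lambda: 42, non_block=True).result())   # [42]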
@@ -524,7 +524,7 @@ class Worker(WorkerBase):

     @torch.inference_mode()
     def sample_tokens(
-        self, grammar_output: "GrammarOutput"
+        self, grammar_output: "GrammarOutput | None"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
         return self.model_runner.sample_tokens(grammar_output)
