[Core] Simplify async KV output aggregation (#28327)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-11-09 09:44:13 -08:00 committed by GitHub
parent 19d91ece4b
commit 289eb6c537
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 45 additions and 153 deletions

View File

@ -9,6 +9,7 @@ from typing import Any
import pytest import pytest
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
@ -28,12 +29,19 @@ class CustomMultiprocExecutor(MultiprocExecutor):
kwargs: dict | None = None, kwargs: dict | None = None,
non_block: bool = False, non_block: bool = False,
unique_reply_rank: int | None = None, unique_reply_rank: int | None = None,
kv_output_aggregator: KVOutputAggregator = None,
) -> Any | list[Any] | Future[Any | list[Any]]: ) -> Any | list[Any] | Future[Any | list[Any]]:
# Drop marker to show that this was run # Drop marker to show that this was run
with open(".marker", "w"): with open(".marker", "w"):
... ...
return super().collective_rpc( return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank method,
timeout,
args,
kwargs,
non_block,
unique_reply_rank,
kv_output_aggregator,
) )

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
import pytest import pytest
@ -86,74 +85,6 @@ def test_aggregate_workers_output():
assert aggregated.invalid_block_ids == {3, 4, 5} assert aggregated.invalid_block_ids == {3, 4, 5}
def test_async_aggregate_workers_output():
    """Exercise ``KVOutputAggregator.async_aggregate`` over four steps.

    A single aggregator (``expected_finished_count=2``) is reused for every
    step, so the steps are order-dependent. Each step wraps a fresh future,
    resolves it with two worker outputs, and checks that the aggregated
    future is already done, yields the first worker's output object, and
    carries the expected ``kv_connector_output`` fields.
    """
    aggregator = KVOutputAggregator(expected_finished_count=2)

    # Per step: (kwargs for worker 0, kwargs for worker 1,
    #            expected finished_sending, expected finished_recving,
    #            expected invalid_block_ids; empty set means "falsy").
    steps = [
        ({}, {}, None, None, set()),
        (
            {"finished_sending": {"req1"}, "finished_recving": {"req2"}},
            {"invalid_block_ids": {1}},
            None,
            None,
            {1},
        ),
        (
            {"invalid_block_ids": {2}},
            {"finished_sending": {"req1"}},
            {"req1"},
            None,
            {2},
        ),
        (
            {"invalid_block_ids": {3, 4}},
            {"finished_recving": {"req2"}, "invalid_block_ids": {4, 5}},
            None,
            {"req2"},
            {3, 4, 5},
        ),
    ]

    for worker0_kwargs, worker1_kwargs, want_send, want_recv, want_invalid in steps:
        pending: Future[list[DummyModelRunnerOutput]] = Future()
        aggregated_future = aggregator.async_aggregate(pending)

        # Outputs are built only after the aggregate future has been created.
        worker0_output = DummyModelRunnerOutput(**worker0_kwargs)
        worker1_output = DummyModelRunnerOutput(**worker1_kwargs)
        pending.set_result([worker0_output, worker1_output])

        assert aggregated_future.done()
        merged = aggregated_future.result()
        assert merged is worker0_output
        connector_output = merged.kv_connector_output

        if want_send is None:
            assert connector_output.finished_sending is None
        else:
            assert connector_output.finished_sending == want_send

        if want_recv is None:
            assert connector_output.finished_recving is None
        else:
            assert connector_output.finished_recving == want_recv

        if want_invalid:
            assert connector_output.invalid_block_ids == want_invalid
        else:
            assert not connector_output.invalid_block_ids
def test_aggregate_workers_output_with_expected_finished_count(): def test_aggregate_workers_output_with_expected_finished_count():
# We create the aggregator expecting to collect from 4 workers # We create the aggregator expecting to collect from 4 workers
aggregator = KVOutputAggregator(expected_finished_count=4) aggregator = KVOutputAggregator(expected_finished_count=4)

View File

@ -4,9 +4,6 @@
KV cache helper for store. KV cache helper for store.
""" """
import contextlib
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
import torch import torch
@ -220,43 +217,6 @@ class KVOutputAggregator:
return output return output
def async_aggregate(
    self,
    output_future: Future[Sequence[ModelRunnerOutput | None]],
    output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
    """Takes a future that resolves to a list of outputs and returns a future
    which resolves to a single aggregated output.

    Args:
        output_future: Future resolving to the sequence of per-worker outputs.
        output_rank: Forwarded unchanged to ``self.aggregate``.

    Returns:
        A new future that resolves to ``self.aggregate(outputs, output_rank)``.
        It is cancelled if reading ``output_future`` raises ``CancelledError``
        and carries any other exception raised while reading or aggregating.
    """
    result_future: Future[ModelRunnerOutput | None] = Future()

    def callback(fut):
        # Nothing to do if the returned future was already resolved/cancelled.
        if result_future.done():
            return
        try:
            result_future.set_result(self.aggregate(fut.result(), output_rank))
        except CancelledError:
            result_future.cancel()
        except Exception as e:
            result_future.set_exception(e)

    output_future.add_done_callback(callback)

    from vllm.v1.executor.multiproc_executor import FutureWrapper

    if isinstance(output_future, FutureWrapper):
        # Due to the threadless implementation of multiproc FutureWrapper,
        # we must block on the delegate future's result() method.
        import time

        delegate_result = result_future.result

        def result(timeout=None):
            # One deadline covers both waits so a caller-supplied timeout is
            # honoured overall instead of being dropped for the second wait.
            deadline = None if timeout is None else time.monotonic() + timeout
            # Failures of the delegate are reported through result_future by
            # the done-callback above, so they are not re-raised here.
            with contextlib.suppress(Exception):
                output_future.result(timeout=timeout)
            remaining = (
                None if deadline is None else max(0.0, deadline - time.monotonic())
            )
            return delegate_result(remaining)

        result_future.result = result  # type: ignore[method-assign]
    return result_future
def _make_src_and_dst_indices( def _make_src_and_dst_indices(
src_block_ids: list[int], src_block_ids: list[int],

View File

@ -29,6 +29,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_dp_group, get_dp_group,
get_ep_group, get_ep_group,
@ -57,8 +58,13 @@ logger = init_logger(__name__)
class FutureWrapper(Future): class FutureWrapper(Future):
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]): def __init__(
self,
futures_queue: deque[tuple["FutureWrapper", Callable]],
aggregate: Callable = lambda x: x,
):
self.futures_queue = futures_queue self.futures_queue = futures_queue
self.aggregate = aggregate
super().__init__() super().__init__()
def result(self, timeout=None): def result(self, timeout=None):
@ -72,7 +78,7 @@ class FutureWrapper(Future):
def wait_for_response(self, get_response: Callable): def wait_for_response(self, get_response: Callable):
try: try:
response = get_response() response = self.aggregate(get_response())
with suppress(InvalidStateError): with suppress(InvalidStateError):
self.set_result(response) self.set_result(response)
except Exception as e: except Exception as e:
@ -160,7 +166,6 @@ class MultiprocExecutor(Executor):
self.futures_queue = deque[tuple[FutureWrapper, Callable]]() self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
def start_worker_monitor(self): def start_worker_monitor(self):
workers = self.workers workers = self.workers
@ -199,44 +204,27 @@ class MultiprocExecutor(Executor):
def execute_model( # type: ignore[override] def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self._execute_with_aggregation( return self.collective_rpc(
"execute_model", scheduler_output, non_block=non_block "execute_model",
args=(scheduler_output,),
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
kv_output_aggregator=self.kv_output_aggregator,
) )
def sample_tokens( # type: ignore[override] def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
return self._execute_with_aggregation( # type: ignore[return-value] return self.collective_rpc(
"sample_tokens", grammar_output, non_block=non_block "sample_tokens",
) args=(grammar_output,),
unique_reply_rank=self.output_rank,
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
return self.collective_rpc(
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
# get output from all workers
outputs = self.collective_rpc(
method,
args=args,
non_block=non_block, non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
kv_output_aggregator=self.kv_output_aggregator,
) )
# aggregate all workers output to a single output
assert self.kv_output_aggregator is not None
if non_block:
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank) self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
@ -254,8 +242,10 @@ class MultiprocExecutor(Executor):
kwargs: dict | None = None, kwargs: dict | None = None,
non_block: bool = False, non_block: bool = False,
unique_reply_rank: int | None = None, unique_reply_rank: int | None = None,
kv_output_aggregator: KVOutputAggregator = None,
) -> Any | list[Any] | Future[Any | list[Any]]: ) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank is provided, otherwise list.""" """Returns single result if unique_reply_rank and/or kv_output_aggregator
is provided, otherwise list."""
if self.is_failed: if self.is_failed:
raise RuntimeError("Executor failed.") raise RuntimeError("Executor failed.")
@ -263,20 +253,23 @@ class MultiprocExecutor(Executor):
deadline = None if timeout is None else time.monotonic() + timeout deadline = None if timeout is None else time.monotonic() + timeout
kwargs = kwargs or {} kwargs = kwargs or {}
# NOTE: If the args are heterogeneous, then we pack them into a list, if kv_output_aggregator is not None:
# and unpack them in the method of every worker, because every worker output_rank = None
# knows their own rank. aggregate: Callable[[Any], Any] = partial(
kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
)
else:
output_rank = unique_reply_rank
aggregate = lambda x: x
if isinstance(method, str): if isinstance(method, str):
send_method = method send_method = method
else: else:
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL) send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank)) self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
workers = ( workers = (
(self.workers[unique_reply_rank],) (self.workers[output_rank],) if output_rank is not None else self.workers
if unique_reply_rank is not None
else self.workers
) )
shutdown_event = self.shutdown_event shutdown_event = self.shutdown_event
@ -299,10 +292,10 @@ class MultiprocExecutor(Executor):
" stack trace above for the root cause" " stack trace above for the root cause"
) )
responses.append(result) responses.append(result)
return responses[0] if unique_reply_rank is not None else responses return responses[0] if output_rank is not None else responses
if non_block: if non_block:
future = FutureWrapper(self.futures_queue) future = FutureWrapper(self.futures_queue, aggregate=aggregate)
self.futures_queue.appendleft((future, get_response)) self.futures_queue.appendleft((future, get_response))
return future return future
@ -311,7 +304,7 @@ class MultiprocExecutor(Executor):
future, get_fut_response = self.futures_queue.pop() future, get_fut_response = self.futures_queue.pop()
future.wait_for_response(get_fut_response) future.wait_for_response(get_fut_response)
return get_response() return aggregate(get_response())
@staticmethod @staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]): def _ensure_worker_termination(worker_procs: list[BaseProcess]):