mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 14:08:47 +08:00
[Core] Simplify async KV output aggregation (#28327)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
19d91ece4b
commit
289eb6c537
@ -9,6 +9,7 @@ from typing import Any
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
@ -28,12 +29,19 @@ class CustomMultiprocExecutor(MultiprocExecutor):
|
|||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
non_block: bool = False,
|
non_block: bool = False,
|
||||||
unique_reply_rank: int | None = None,
|
unique_reply_rank: int | None = None,
|
||||||
|
kv_output_aggregator: KVOutputAggregator = None,
|
||||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||||
# Drop marker to show that this was run
|
# Drop marker to show that this was run
|
||||||
with open(".marker", "w"):
|
with open(".marker", "w"):
|
||||||
...
|
...
|
||||||
return super().collective_rpc(
|
return super().collective_rpc(
|
||||||
method, timeout, args, kwargs, non_block, unique_reply_rank
|
method,
|
||||||
|
timeout,
|
||||||
|
args,
|
||||||
|
kwargs,
|
||||||
|
non_block,
|
||||||
|
unique_reply_rank,
|
||||||
|
kv_output_aggregator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from concurrent.futures import Future
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -86,74 +85,6 @@ def test_aggregate_workers_output():
|
|||||||
assert aggregated.invalid_block_ids == {3, 4, 5}
|
assert aggregated.invalid_block_ids == {3, 4, 5}
|
||||||
|
|
||||||
|
|
||||||
def test_async_aggregate_workers_output():
|
|
||||||
aggregator = KVOutputAggregator(expected_finished_count=2)
|
|
||||||
|
|
||||||
future: Future[list[DummyModelRunnerOutput]] = Future()
|
|
||||||
result_future = aggregator.async_aggregate(future)
|
|
||||||
|
|
||||||
output1 = DummyModelRunnerOutput()
|
|
||||||
output2 = DummyModelRunnerOutput()
|
|
||||||
future.set_result([output1, output2])
|
|
||||||
|
|
||||||
assert result_future.done()
|
|
||||||
aggregated = result_future.result()
|
|
||||||
assert aggregated is output1
|
|
||||||
aggregated = aggregated.kv_connector_output
|
|
||||||
assert aggregated.finished_sending is None
|
|
||||||
assert aggregated.finished_recving is None
|
|
||||||
assert not aggregated.invalid_block_ids
|
|
||||||
|
|
||||||
future = Future()
|
|
||||||
result_future = aggregator.async_aggregate(future)
|
|
||||||
|
|
||||||
output1 = DummyModelRunnerOutput(
|
|
||||||
finished_sending={"req1"}, finished_recving={"req2"}
|
|
||||||
)
|
|
||||||
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
|
|
||||||
future.set_result([output1, output2])
|
|
||||||
|
|
||||||
assert result_future.done()
|
|
||||||
aggregated = result_future.result()
|
|
||||||
assert aggregated is output1
|
|
||||||
aggregated = aggregated.kv_connector_output
|
|
||||||
assert aggregated.finished_sending is None
|
|
||||||
assert aggregated.finished_recving is None
|
|
||||||
assert aggregated.invalid_block_ids == {1}
|
|
||||||
|
|
||||||
future = Future()
|
|
||||||
result_future = aggregator.async_aggregate(future)
|
|
||||||
|
|
||||||
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
|
|
||||||
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
|
|
||||||
future.set_result([output1, output2])
|
|
||||||
|
|
||||||
assert result_future.done()
|
|
||||||
aggregated = result_future.result()
|
|
||||||
assert aggregated is output1
|
|
||||||
aggregated = aggregated.kv_connector_output
|
|
||||||
assert aggregated.finished_sending == {"req1"}
|
|
||||||
assert aggregated.finished_recving is None
|
|
||||||
assert aggregated.invalid_block_ids == {2}
|
|
||||||
|
|
||||||
future = Future()
|
|
||||||
result_future = aggregator.async_aggregate(future)
|
|
||||||
|
|
||||||
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
|
|
||||||
output2 = DummyModelRunnerOutput(
|
|
||||||
finished_recving={"req2"}, invalid_block_ids={4, 5}
|
|
||||||
)
|
|
||||||
future.set_result([output1, output2])
|
|
||||||
|
|
||||||
assert result_future.done()
|
|
||||||
aggregated = result_future.result()
|
|
||||||
assert aggregated is output1
|
|
||||||
aggregated = aggregated.kv_connector_output
|
|
||||||
assert aggregated.finished_sending is None
|
|
||||||
assert aggregated.finished_recving == {"req2"}
|
|
||||||
assert aggregated.invalid_block_ids == {3, 4, 5}
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_workers_output_with_expected_finished_count():
|
def test_aggregate_workers_output_with_expected_finished_count():
|
||||||
# We create the aggregator expecting to collect from 4 workers
|
# We create the aggregator expecting to collect from 4 workers
|
||||||
aggregator = KVOutputAggregator(expected_finished_count=4)
|
aggregator = KVOutputAggregator(expected_finished_count=4)
|
||||||
|
|||||||
@ -4,9 +4,6 @@
|
|||||||
KV cache helper for store.
|
KV cache helper for store.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from concurrent.futures import CancelledError, Future
|
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -220,43 +217,6 @@ class KVOutputAggregator:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def async_aggregate(
|
|
||||||
self,
|
|
||||||
output_future: Future[Sequence[ModelRunnerOutput | None]],
|
|
||||||
output_rank: int = 0,
|
|
||||||
) -> Future[ModelRunnerOutput | None]:
|
|
||||||
"""Takes a future that resolves to a list of outputs and returns a future
|
|
||||||
which resolves to a single aggregated output."""
|
|
||||||
result_future: Future[ModelRunnerOutput | None] = Future()
|
|
||||||
|
|
||||||
def callback(fut):
|
|
||||||
if result_future.done():
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
result_future.set_result(self.aggregate(fut.result(), output_rank))
|
|
||||||
except CancelledError:
|
|
||||||
result_future.cancel()
|
|
||||||
except Exception as e:
|
|
||||||
result_future.set_exception(e)
|
|
||||||
|
|
||||||
output_future.add_done_callback(callback)
|
|
||||||
|
|
||||||
from vllm.v1.executor.multiproc_executor import FutureWrapper
|
|
||||||
|
|
||||||
if isinstance(output_future, FutureWrapper):
|
|
||||||
# Due to the threadless implementation of multiproc FutureWrapper,
|
|
||||||
# we must block on the delegate future's result() method.
|
|
||||||
delegate_result = result_future.result
|
|
||||||
|
|
||||||
def result(timeout=None):
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
output_future.result(timeout=timeout)
|
|
||||||
return delegate_result()
|
|
||||||
|
|
||||||
result_future.result = result # type: ignore[method-assign]
|
|
||||||
|
|
||||||
return result_future
|
|
||||||
|
|
||||||
|
|
||||||
def _make_src_and_dst_indices(
|
def _make_src_and_dst_indices(
|
||||||
src_block_ids: list[int],
|
src_block_ids: list[int],
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import vllm.envs as envs
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
|
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
|
||||||
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_dp_group,
|
get_dp_group,
|
||||||
get_ep_group,
|
get_ep_group,
|
||||||
@ -57,8 +58,13 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class FutureWrapper(Future):
|
class FutureWrapper(Future):
|
||||||
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
futures_queue: deque[tuple["FutureWrapper", Callable]],
|
||||||
|
aggregate: Callable = lambda x: x,
|
||||||
|
):
|
||||||
self.futures_queue = futures_queue
|
self.futures_queue = futures_queue
|
||||||
|
self.aggregate = aggregate
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def result(self, timeout=None):
|
def result(self, timeout=None):
|
||||||
@ -72,7 +78,7 @@ class FutureWrapper(Future):
|
|||||||
|
|
||||||
def wait_for_response(self, get_response: Callable):
|
def wait_for_response(self, get_response: Callable):
|
||||||
try:
|
try:
|
||||||
response = get_response()
|
response = self.aggregate(get_response())
|
||||||
with suppress(InvalidStateError):
|
with suppress(InvalidStateError):
|
||||||
self.set_result(response)
|
self.set_result(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -160,7 +166,6 @@ class MultiprocExecutor(Executor):
|
|||||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||||
|
|
||||||
self.output_rank = self._get_output_rank()
|
self.output_rank = self._get_output_rank()
|
||||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
|
||||||
|
|
||||||
def start_worker_monitor(self):
|
def start_worker_monitor(self):
|
||||||
workers = self.workers
|
workers = self.workers
|
||||||
@ -199,44 +204,27 @@ class MultiprocExecutor(Executor):
|
|||||||
def execute_model( # type: ignore[override]
|
def execute_model( # type: ignore[override]
|
||||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||||
return self._execute_with_aggregation(
|
return self.collective_rpc(
|
||||||
"execute_model", scheduler_output, non_block=non_block
|
"execute_model",
|
||||||
|
args=(scheduler_output,),
|
||||||
|
unique_reply_rank=self.output_rank,
|
||||||
|
non_block=non_block,
|
||||||
|
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||||
|
kv_output_aggregator=self.kv_output_aggregator,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_tokens( # type: ignore[override]
|
def sample_tokens( # type: ignore[override]
|
||||||
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||||
return self._execute_with_aggregation( # type: ignore[return-value]
|
return self.collective_rpc(
|
||||||
"sample_tokens", grammar_output, non_block=non_block
|
"sample_tokens",
|
||||||
)
|
args=(grammar_output,),
|
||||||
|
unique_reply_rank=self.output_rank,
|
||||||
def _execute_with_aggregation(
|
|
||||||
self, method: str, *args, non_block: bool = False
|
|
||||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
|
||||||
if not self.has_connector:
|
|
||||||
# get output only from a single worker (output_rank)
|
|
||||||
return self.collective_rpc(
|
|
||||||
method,
|
|
||||||
args=args,
|
|
||||||
unique_reply_rank=self.output_rank,
|
|
||||||
non_block=non_block,
|
|
||||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get output from all workers
|
|
||||||
outputs = self.collective_rpc(
|
|
||||||
method,
|
|
||||||
args=args,
|
|
||||||
non_block=non_block,
|
non_block=non_block,
|
||||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||||
|
kv_output_aggregator=self.kv_output_aggregator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# aggregate all workers output to a single output
|
|
||||||
assert self.kv_output_aggregator is not None
|
|
||||||
if non_block:
|
|
||||||
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
|
|
||||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
|
||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
|
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
|
||||||
|
|
||||||
@ -254,8 +242,10 @@ class MultiprocExecutor(Executor):
|
|||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
non_block: bool = False,
|
non_block: bool = False,
|
||||||
unique_reply_rank: int | None = None,
|
unique_reply_rank: int | None = None,
|
||||||
|
kv_output_aggregator: KVOutputAggregator = None,
|
||||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||||
"""Returns single result if unique_reply_rank is provided, otherwise list."""
|
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
|
||||||
|
is provided, otherwise list."""
|
||||||
|
|
||||||
if self.is_failed:
|
if self.is_failed:
|
||||||
raise RuntimeError("Executor failed.")
|
raise RuntimeError("Executor failed.")
|
||||||
@ -263,20 +253,23 @@ class MultiprocExecutor(Executor):
|
|||||||
deadline = None if timeout is None else time.monotonic() + timeout
|
deadline = None if timeout is None else time.monotonic() + timeout
|
||||||
kwargs = kwargs or {}
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
if kv_output_aggregator is not None:
|
||||||
# and unpack them in the method of every worker, because every worker
|
output_rank = None
|
||||||
# knows their own rank.
|
aggregate: Callable[[Any], Any] = partial(
|
||||||
|
kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_rank = unique_reply_rank
|
||||||
|
aggregate = lambda x: x
|
||||||
|
|
||||||
if isinstance(method, str):
|
if isinstance(method, str):
|
||||||
send_method = method
|
send_method = method
|
||||||
else:
|
else:
|
||||||
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
|
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
|
||||||
|
|
||||||
workers = (
|
workers = (
|
||||||
(self.workers[unique_reply_rank],)
|
(self.workers[output_rank],) if output_rank is not None else self.workers
|
||||||
if unique_reply_rank is not None
|
|
||||||
else self.workers
|
|
||||||
)
|
)
|
||||||
|
|
||||||
shutdown_event = self.shutdown_event
|
shutdown_event = self.shutdown_event
|
||||||
@ -299,10 +292,10 @@ class MultiprocExecutor(Executor):
|
|||||||
" stack trace above for the root cause"
|
" stack trace above for the root cause"
|
||||||
)
|
)
|
||||||
responses.append(result)
|
responses.append(result)
|
||||||
return responses[0] if unique_reply_rank is not None else responses
|
return responses[0] if output_rank is not None else responses
|
||||||
|
|
||||||
if non_block:
|
if non_block:
|
||||||
future = FutureWrapper(self.futures_queue)
|
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
|
||||||
self.futures_queue.appendleft((future, get_response))
|
self.futures_queue.appendleft((future, get_response))
|
||||||
return future
|
return future
|
||||||
|
|
||||||
@ -311,7 +304,7 @@ class MultiprocExecutor(Executor):
|
|||||||
future, get_fut_response = self.futures_queue.pop()
|
future, get_fut_response = self.futures_queue.pop()
|
||||||
future.wait_for_response(get_fut_response)
|
future.wait_for_response(get_fut_response)
|
||||||
|
|
||||||
return get_response()
|
return aggregate(get_response())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user