Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:34:57 +08:00)

[BugFix] Make PD work with Ray (#21072)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in: parent 6a971ed692, commit 9f414a12ad
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch

import pytest
import ray

from vllm import LLM
from vllm.config import KVTransferConfig
@@ -15,11 +16,32 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
from vllm.sampling_params import SamplingParams

from .utils import create_request, create_scheduler, create_vllm_config


def _make_stub_pkg() -> str:
    """Return a directory that makes
    `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
    td = tempfile.mkdtemp()
    pkg_root = os.path.join(td, "nixl", "_api")
    os.makedirs(pkg_root, exist_ok=True)

    stub = textwrap.dedent("""\
        # Forward the real FakeNixlWrapper that the driver already defined.
        print("In fake package")
        from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
        """)
    with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
        f.write(stub)

    # touch parent package
    open(os.path.join(td, "nixl", "__init__.py"), "w").close()
    return td


def test_basic_interface():
    """Unit test for basic NixlConnector interface functionality."""
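The stub package above works because Ray ships the returned directory as the job's working_dir and puts it on the import path of every worker, so `from nixl._api import nixl_agent` resolves to the fake wrapper even when nixl is not installed. A minimal, self-contained sketch of the same trick (sys.path stands in for Ray's working_dir here; vllm.mocks.mock_nixl_connector is the module added by this commit):

import os
import sys
import tempfile

# Build the stub package the same way _make_stub_pkg() does.
td = tempfile.mkdtemp()
os.makedirs(os.path.join(td, "nixl", "_api"), exist_ok=True)
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
with open(os.path.join(td, "nixl", "_api", "__init__.py"), "w") as f:
    f.write("from vllm.mocks.mock_nixl_connector import "
            "FakeNixlWrapper as nixl_agent\n")

# On a Ray worker the working_dir is added to the import path automatically;
# here we do it by hand.
sys.path.insert(0, td)
from nixl._api import nixl_agent  # resolves to FakeNixlWrapper

assert nixl_agent("agent0").get_agent_metadata() == b"fake_agent_metadata"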
@@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
    assert len(scheduler_output.scheduled_new_reqs) == 1


class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        self._cycles_before_xfer_done = 0
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
            lambda: 0)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        # Used to collect done_sending, which we don't test yet.
        return {}

    def check_xfer_state(self, handle: int) -> str:
        if self._check_xfer_state_cycles[
                handle] >= self._cycles_before_xfer_done:
            return "DONE"
        self._check_xfer_state_cycles[handle] += 1
        return "PROC"

    def release_xfer_handle(self, handle: int) -> None:
        pass

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        pass

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        return "PROC"

    ############################################################
    # The following are for changing behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles


class FakeNixlConnectorWorker(NixlConnectorWorker):

    REMOTE_ENGINE_ID = "remote_engine"

@@ -378,10 +329,14 @@ class TestNixlHandshake:
            raise TimeoutError("Took too long to complete async handshake.")


# NOTE: resource cleanup in the mp backend is a bit finicky, so the
# parametrization order here matters: run the Ray case first so it cleans up
# its resources, then run the rest of the tests.
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_abort_timeout_on_prefiller(monkeypatch):
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
    """
    Test lifecycle of an aborted Remote Prefill request hitting the timeout.
    -----> P

@@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
    timeout = 6
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))

    # Build runtime_env only if we're using Ray
    if distributed_executor_backend == "ray":
        runtime_env = {
            "working_dir": _make_stub_pkg(),  # ship stub package
            "env_vars": {
                "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
            },
        }
        ray.init(runtime_env=runtime_env)

    llm = LLM(
        model=model_name,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
        distributed_executor_backend=distributed_executor_backend,
    )
    remote_prefill_opts = {
        "do_remote_decode": True,
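Presumably the test routes the timeout through ray.init(runtime_env=...) because monkeypatch.setenv only affects the driver process, while env_vars in the runtime_env (along with the working_dir carrying the stub package) are what the Ray worker processes inherit. A standalone sketch of that propagation, using plain Ray with an illustrative task rather than vLLM:

import os

import ray

# Hypothetical, vLLM-independent example: tasks launched under this job
# inherit the env var set in the runtime_env.
ray.init(runtime_env={"env_vars": {"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": "6"}})


@ray.remote
def read_timeout() -> str:
    return os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"]


assert ray.get(read_timeout.remote()) == "6"
ray.shutdown()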
@@ -1,28 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from concurrent.futures import Future
from typing import Optional

from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput


class DummyMultiprocExecutor(MultiprocExecutor):

    def __init__(self, output_rank, world_size):
        # Manually initialize minimal required fields
        self.output_rank = output_rank
        self.world_size = world_size
        self._send_remaining_count = defaultdict[str,
                                                 int](lambda: self.world_size)
        self._recv_remaining_count = defaultdict[str,
                                                 int](lambda: self.world_size)
        self.io_thread_pool = None
        self.shutdown_event = threading.Event()


class DummyModelRunnerOutput(ModelRunnerOutput):

    def __init__(self,

@@ -33,14 +17,14 @@ class DummyModelRunnerOutput(ModelRunnerOutput):


def test_aggregate_workers_output():
    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
    aggregator = KVOutputAggregator(world_size=2)

    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                     finished_recving={'req2'})
    output2 = DummyModelRunnerOutput(finished_sending=None,
                                     finished_recving=None)

    aggregated = executor._aggregate_workers_output([output1, output2])
    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    assert aggregated.finished_sending is None

@@ -51,7 +35,7 @@ def test_aggregate_workers_output():
    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                     finished_recving=None)

    aggregated = executor._aggregate_workers_output([output1, output2])
    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    assert aggregated.finished_sending == {'req1'}

@@ -62,7 +46,7 @@ def test_aggregate_workers_output():
    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                     finished_recving={'req2'})

    aggregated = executor._aggregate_workers_output([output1, output2])
    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    assert aggregated.finished_sending is None

@@ -70,12 +54,11 @@


def test_async_aggregate_workers_output():
    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
    aggregator = KVOutputAggregator(world_size=2)

    future1: Future[DummyModelRunnerOutput] = Future()
    future2: Future[DummyModelRunnerOutput] = Future()
    result_future = executor._async_aggregate_workers_output(
        [future1, future2])
    result_future = aggregator.async_aggregate([future1, future2])

    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                     finished_recving={'req2'})

@@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():

    future1 = Future()
    future2 = Future()
    result_future = executor._async_aggregate_workers_output(
        [future1, future2])
    result_future = aggregator.async_aggregate([future1, future2])

    output1 = DummyModelRunnerOutput(finished_sending=None,
                                     finished_recving=None)

@@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():

    future1 = Future()
    future2 = Future()
    result_future = executor._async_aggregate_workers_output(
        [future1, future2])
    result_future = aggregator.async_aggregate([future1, future2])

    output1 = DummyModelRunnerOutput(finished_sending=None,
                                     finished_recving=None)
@@ -3,12 +3,18 @@
"""
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)

@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
                "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"


class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for the scheduler."""

    def __init__(self, world_size: int):
        # Completed-transfer tracker, used to track finished requests
        # [req_id -> n_finished_workers]
        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
        self._send_remaining_count = defaultdict[str, int](lambda: world_size)

    def aggregate(self,
                  outputs: list[ModelRunnerOutput],
                  output_rank: int = 0) -> ModelRunnerOutput:
        # aggregate finished_sending, finished_recving from all workers

        def update_finished_set(req_ids: Optional[set[str]],
                                remaining_count_dict: dict[str, int],
                                finished_set: set[str]) -> None:
            for req_id in req_ids or ():
                new_count = remaining_count_dict[req_id] - 1
                if new_count == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]
                else:
                    remaining_count_dict[req_id] = new_count

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            update_finished_set(output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        # set the aggregated finished_sending / finished_recving
        # if output.finished_sending/recving is not empty, but the other ranks
        # still have unfinished send/recv, we want to set the aggregated
        # finished_sending/recving to None until all ranks have finished
        # send/recv
        output.finished_sending = finished_sending if finished_sending else None
        output.finished_recving = finished_recving if finished_recving else None

        return output

    def async_aggregate(self,
                        output_futures: Sequence[Future[ModelRunnerOutput]],
                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
        """Take a list of futures and return a single future which resolves
        to the aggregated output once all of them complete."""
        result_future: Future[ModelRunnerOutput] = Future()

        outputs: list[Optional[ModelRunnerOutput]] = [None
                                                      ] * len(output_futures)

        def make_callback(idx):

            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                # this check assumes io_thread_pool uses a single thread
                if all(outputs):
                    result_future.set_result(
                        self.aggregate(cast(list[ModelRunnerOutput], outputs),
                                       output_rank))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))

        return result_future
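A hedged usage sketch of the aggregation semantics above. types.SimpleNamespace stands in for ModelRunnerOutput because the aggregator only reads and writes the finished_sending / finished_recving fields; a request is only surfaced once every worker has reported it:

from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

aggregator = KVOutputAggregator(world_size=2)

# Step 1: only worker 0 has finished sending req1, so nothing is reported yet.
out0 = SimpleNamespace(finished_sending={"req1"}, finished_recving=None)
out1 = SimpleNamespace(finished_sending=None, finished_recving=None)
assert aggregator.aggregate([out0, out1]).finished_sending is None

# Step 2: worker 1 catches up, so the aggregate now reports req1 as finished.
out0 = SimpleNamespace(finished_sending=None, finished_recving=None)
out1 = SimpleNamespace(finished_sending={"req1"}, finished_recving=None)
assert aggregator.aggregate([out0, out1]).finished_sending == {"req1"}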
@@ -194,7 +194,7 @@ class KVConnectorBase_V1(ABC):
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the MultiprocExecutor) will use this output
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
0   vllm/mocks/__init__.py   Normal file
76  vllm/mocks/mock_nixl_connector.py   Normal file
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from collections import defaultdict
from typing import Optional


class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        self._cycles_before_xfer_done = 0
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
            lambda: 0)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        # Used to collect done_sending, which we don't test yet.
        return {}

    def check_xfer_state(self, handle: int) -> str:
        if self._check_xfer_state_cycles[
                handle] >= self._cycles_before_xfer_done:
            return "DONE"
        self._check_xfer_state_cycles[handle] += 1
        return "PROC"

    def release_xfer_handle(self, handle: int) -> None:
        pass

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        pass

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        return "PROC"

    ############################################################
    # The following are for changing behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles
@@ -1188,9 +1188,15 @@ class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
@@ -9,8 +9,7 @@ import threading
import time
import traceback
import weakref
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial

@@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                  MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.multiproc_worker_utils import (
    _add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger

@@ -118,13 +118,8 @@ class MultiprocExecutor(Executor):

        self.output_rank = self._get_output_rank()
        self.has_connector = self.vllm_config.kv_transfer_config is not None

        # Complete transfer tracker. Used by to track finished requests
        # [req_id -> n_finished_workers]
        self._recv_remaining_count = defaultdict[str,
                                                 int](lambda: self.world_size)
        self._send_remaining_count = defaultdict[str,
                                                 int](lambda: self.world_size)
        self.kv_output_aggregator = KVOutputAggregator(
            self.parallel_config.world_size)

    def start_worker_monitor(self):
        workers = self.workers

@@ -186,8 +181,9 @@ class MultiprocExecutor(Executor):

        # aggregate all workers output to a single output
        if non_block:
            return self._async_aggregate_workers_output(outputs)
        return self._aggregate_workers_output(outputs)
            return self.kv_output_aggregator.async_aggregate(
                outputs, self.output_rank)
        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

    def collective_rpc(self,
                       method: Union[str, Callable],

@@ -246,74 +242,6 @@ class MultiprocExecutor(Executor):
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

    def _aggregate_workers_output(
            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
        # aggregate finished_sending, finished_recving from all workers

        def update_finished_set(req_ids: Optional[set[str]],
                                remaining_count_dict: dict[str, int],
                                finished_set: set[str]) -> None:
            for req_id in req_ids or ():
                new_count = remaining_count_dict[req_id] - 1
                if new_count == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]
                else:
                    remaining_count_dict[req_id] = new_count

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            update_finished_set(output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)

        # select output of the worker specified by output_rank
        output = outputs[self.output_rank]

        # set the aggregated finished_sending / finished_recving
        output.finished_sending = finished_sending if finished_sending else None
        output.finished_recving = finished_recving if finished_recving else None

        return output

    def _async_aggregate_workers_output(
            self, output_futures: list[Future[ModelRunnerOutput]]
    ) -> (Future[ModelRunnerOutput]):
        """Takes a list of futures and returns a single future which resolves
        to the respective list of outputs."""
        result_future: Future[ModelRunnerOutput] = Future()

        outputs: list[Optional[ModelRunnerOutput]] = [None
                                                      ] * len(output_futures)

        def make_callback(idx):

            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                # this check assumes io_thread_pool uses a single thread
                if all(outputs):
                    result_future.set_result(
                        self._aggregate_workers_output(
                            cast(list[ModelRunnerOutput], outputs)))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))

        return result_future

    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated. Assumes workers have
@@ -2,33 +2,55 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from concurrent.futures import Future
from typing import Union
from typing import Optional, Union

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.logger import init_logger
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)


class FutureWrapper(Future):
    """A wrapper around a Ray output reference to meet the interface
    of .execute_model().
    """A wrapper around Ray output references to meet the interface
    of .execute_model(): the top level (core busy loop) expects the .result()
    API to block and return a single output.

    If an aggregator is provided, the outputs from all workers are aggregated
    upon the result() call. If not, only the first worker's output is returned.
    """

    def __init__(self, ref):
    def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
        super().__init__()
        self.ref = ref
        self.refs = refs
        self.aggregator = aggregator

    def result(self, timeout=None):
        if timeout is not None:
            raise NotImplementedError("timeout is not supported")
        return self.ref.get()

        if self.aggregator is None:
            return self.refs[0].get()

        outputs = [ref.get() for ref in self.refs]
        return self.aggregator.aggregate(outputs, output_rank=0)


class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray distributed executor using Ray Compiled Graphs."""

    def _init_executor(self) -> None:
        super()._init_executor()

        # KV connector setup
        self.has_connector = self.vllm_config.kv_transfer_config is not None
        self.kv_output_aggregator = KVOutputAggregator(
            self.parallel_config.world_size)

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,

@@ -56,13 +78,24 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        # When PP is not used, we block here until the result is available.
        if self.max_concurrent_batches == 1:
            return refs[0].get()
        if not self.has_connector:
            # Get output only from a single worker (output_rank)
            # When PP is not used, we block here until the result is available.
            if self.max_concurrent_batches == 1:
                return refs[0].get()

        # When PP is used, we return a FutureWrapper immediately so that
        # the scheduler can yield to the next batch.
        return FutureWrapper(refs[0])
            # When PP is used, we return a FutureWrapper immediately so that
            # the scheduler can yield to the next batch.
            return FutureWrapper(refs)

        # Get output from all workers when connector is present
        if self.max_concurrent_batches == 1:
            # Block and get results from all workers
            outputs = [ref.get() for ref in refs]
            return self.kv_output_aggregator.aggregate(outputs)

        # Return a future that will aggregate outputs from all workers
        return FutureWrapper(refs, self.kv_output_aggregator)

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
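A hedged sketch of the FutureWrapper behavior above. The vllm.v1.executor.ray_distributed_executor module path is an assumption, and _Ref is a local stand-in for a Ray object ref whose .get() blocks and returns one worker's output:

from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.executor.ray_distributed_executor import FutureWrapper


class _Ref:
    """Stand-in for a Ray object ref: .get() returns the worker's output."""

    def __init__(self, output):
        self._output = output

    def get(self):
        return self._output


out0 = SimpleNamespace(finished_sending={"req1"}, finished_recving=None)
out1 = SimpleNamespace(finished_sending={"req1"}, finished_recving=None)

# Without an aggregator, only the first worker's output is returned.
assert FutureWrapper([_Ref(out0), _Ref(out1)]).result() is out0

# With an aggregator, result() gathers every ref and merges the finished sets.
fut = FutureWrapper([_Ref(out0), _Ref(out1)], KVOutputAggregator(world_size=2))
assert fut.result().finished_sending == {"req1"}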
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import gc
import time
from contextlib import contextmanager

@@ -1270,6 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]],
    ) -> ModelRunnerOutput:
        assert self.input_batch.num_reqs ==\
            len(self.input_batch.pooling_params), \

@@ -1304,6 +1307,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
            finished_sending=finished_sending,
            finished_recving=finished_recving,
        )

    @torch.inference_mode()

@@ -1314,12 +1319,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if has_kv_transfer_group():
                with set_forward_context(None, self.vllm_config):
                    self.maybe_setup_kv_connector(scheduler_output)
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,

@@ -1412,6 +1416,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )

        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = (
            self.get_finished_kv_transfers(scheduler_output))

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output

@@ -1429,6 +1435,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                if finished_sending or finished_recving:
                    hidden_states.finished_sending = finished_sending
                    hidden_states.finished_recving = finished_recving
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,

@@ -1437,7 +1446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np)
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

@@ -1587,6 +1597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
        )

@@ -1711,6 +1723,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids)
        return None, None

    def kv_connector_no_forward(
            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        # KV send/recv even if no work to do.
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if not finished_sending and not finished_recving:
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.finished_sending = finished_sending
        output.finished_recving = finished_recving
        return output

    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
@@ -15,9 +15,7 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest

@@ -335,25 +333,17 @@ class Worker(WorkerBase):
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            output = EMPTY_MODEL_RUNNER_OUTPUT

            # In case of PP with kv transfer, we need to pass through the
            # finished_sending and finished_recving buffers.
            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
            if output.finished_sending or output.finished_recving:
                empty_output = copy.copy(empty_output)
                empty_output.finished_sending = output.finished_sending
                empty_output.finished_recving = output.finished_recving
            output = empty_output

        assert isinstance(output, ModelRunnerOutput)
        if has_kv_transfer_group():
            finished_sending, finished_recving = (
                get_kv_transfer_group().get_finished(
                    scheduler_output.finished_req_ids))
            if finished_sending or finished_recving:
                if output is EMPTY_MODEL_RUNNER_OUTPUT:
                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
                output.finished_sending = finished_sending
                output.finished_recving = finished_recving

            # Clear KVConnector state for this step.
            get_kv_transfer_group().clear_connector_metadata()

            # with a connector, the scheduler expects output from all workers
            return output

        # return output only from the driver worker
        return output if self.is_driver_worker else None