diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index c4f558b7acdb..a0dfd54fb825 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import tempfile
+import textwrap
 import time
-import uuid
-from collections import defaultdict
-from typing import Optional
 from unittest.mock import patch

 import pytest
+import ray

 from vllm import LLM
 from vllm.config import KVTransferConfig
@@ -15,11 +16,32 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
 from vllm.sampling_params import SamplingParams

 from .utils import create_request, create_scheduler, create_vllm_config


+def _make_stub_pkg() -> str:
+    """Return a directory that makes
+    `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
+    td = tempfile.mkdtemp()
+    pkg_root = os.path.join(td, "nixl", "_api")
+    os.makedirs(pkg_root, exist_ok=True)
+
+    stub = textwrap.dedent("""\
+        # Forward the real FakeNixlWrapper that the driver already defined.
+        print("In fake package")
+        from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
+        """)
+    with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
+        f.write(stub)
+
+    # touch parent package
+    open(os.path.join(td, "nixl", "__init__.py"), "w").close()
+    return td
+
+
 def test_basic_interface():
     """Unit test for basic NixlConnector interface functionality."""

@@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
     assert len(scheduler_output.scheduled_new_reqs) == 1


-class FakeNixlWrapper:
-    """Mock implementation of NixlWrapper for testing.
-
-    We don't inherit from nixl._api.nixl_agent because nixl may not be
-    installed.
-    """
-
-    AGENT_METADATA = b"fake_agent_metadata"
-    REMOTE_AGENT_NAME = "remote_agent"
-
-    def __init__(self, agent_name: str, *args, **kwargs):
-        self._cycles_before_xfer_done = 0
-        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
-            lambda: 0)
-
-    def get_reg_descs(self, caches_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in caches_data]
-
-    def register_memory(self, descs) -> None:
-        pass
-
-    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in blocks_data]
-
-    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
-        return uuid.uuid4().int
-
-    def get_agent_metadata(self) -> bytes:
-        return self.AGENT_METADATA
-
-    def add_remote_agent(self, agent_metadata: bytes) -> str:
-        return self.REMOTE_AGENT_NAME
-
-    def get_new_notifs(self) -> dict[str, list[bytes]]:
-        # Used to collect done_sending, which we don't test yet.
-        return {}
-
-    def check_xfer_state(self, handle: int) -> str:
-        if self._check_xfer_state_cycles[
-                handle] >= self._cycles_before_xfer_done:
-            return "DONE"
-        self._check_xfer_state_cycles[handle] += 1
-        return "PROC"
-
-    def release_xfer_handle(self, handle: int) -> None:
-        pass
-
-    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
-        pass
-
-    def make_prepped_xfer(self,
-                          xfer_type: str,
-                          local_xfer_side_handle: int,
-                          local_block_descs_ids: list[int],
-                          remote_xfer_side_handle: int,
-                          remote_block_descs_ids: list[int],
-                          notif_msg: Optional[bytes] = None) -> int:
-        return uuid.uuid4().int
-
-    def transfer(self, handle: int) -> str:
-        return "PROC"
-
-    ############################################################
-    # Follow are for changing the behavior during testing.
-    ############################################################
-
-    def set_cycles_before_xfer_done(self, cycles: int):
-        """Set the number of cycles before a transfer is considered done."""
-        self._cycles_before_xfer_done = cycles
-
-
 class FakeNixlConnectorWorker(NixlConnectorWorker):

     REMOTE_ENGINE_ID = "remote_engine"
@@ -378,10 +329,14 @@ class TestNixlHandshake:
             raise TimeoutError("Took too long to complete async handshake.")


+# NOTE: resource cleanup in the mp backend is a bit finicky, so the ordering
+# of the backends here matters: run the ray backend first so it cleans up its
+# resources before the remaining tests run.
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
     FakeNixlWrapper)
-def test_abort_timeout_on_prefiller(monkeypatch):
+def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
     """
     Test lifecycle of an aborted Remote Prefill request hitting the timeout.
     -----> P
@@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
     timeout = 6
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+
+    # Build runtime_env only if we're using Ray
+    if distributed_executor_backend == "ray":
+        runtime_env = {
+            "working_dir": _make_stub_pkg(),  # ship stub package
+            "env_vars": {
+                "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
+            },
+        }
+        ray.init(runtime_env=runtime_env)
+
     llm = LLM(
         model=model_name,
         enforce_eager=True,
         gpu_memory_utilization=0.5,
         kv_transfer_config=kv_transfer_config,
+        distributed_executor_backend=distributed_executor_backend,
     )
     remote_prefill_opts = {
         "do_remote_decode": True,
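For context, the stub package above works because Ray makes a shipped working_dir importable on each worker, so the connector's `from nixl._api import nixl_agent` resolves to the fake wrapper even without a real NIXL install. A minimal sketch of that resolution (assuming `_make_stub_pkg` from the test above is importable; not part of the patch):

import sys

stub_dir = _make_stub_pkg()        # directory holding nixl/_api/__init__.py
sys.path.insert(0, stub_dir)       # roughly what Ray does for working_dir

from nixl._api import nixl_agent   # now resolves to FakeNixlWrapper

agent = nixl_agent("worker-0")
assert agent.get_agent_metadata() == b"fake_agent_metadata"
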
diff --git a/tests/v1/executor/test_multiproc_executor.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py
similarity index 72%
rename from tests/v1/executor/test_multiproc_executor.py
rename to tests/v1/kv_connector/unit/test_output_aggreagator.py
index c1425d82becf..cad73f68e9f1 100644
--- a/tests/v1/executor/test_multiproc_executor.py
+++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py
@@ -1,28 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import threading
-from collections import defaultdict
 from concurrent.futures import Future
 from typing import Optional

-from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.v1.outputs import ModelRunnerOutput


-class DummyMultiprocExecutor(MultiprocExecutor):
-
-    def __init__(self, output_rank, world_size):
-        # Manually initialize minimal required fields
-        self.output_rank = output_rank
-        self.world_size = world_size
-        self._send_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self._recv_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self.io_thread_pool = None
-        self.shutdown_event = threading.Event()
-
-
 class DummyModelRunnerOutput(ModelRunnerOutput):

     def __init__(self,
@@ -33,14 +17,14 @@ class DummyModelRunnerOutput(ModelRunnerOutput):


 def test_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)

     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
     output2 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)

-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
     assert aggregated.finished_sending is None
@@ -51,7 +35,7 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving=None)

-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
     assert aggregated.finished_sending == {'req1'}
@@ -62,7 +46,7 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})

-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
     assert aggregated.finished_sending is None
@@ -70,12 +54,11 @@ def test_aggregate_workers_output():


 def test_async_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)

     future1: Future[DummyModelRunnerOutput] = Future()
     future2: Future[DummyModelRunnerOutput] = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])

     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
@@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():

     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])

     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
@@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():

     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])

     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 5cbc8ca31752..c179d6cc29b7 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,12 +3,18 @@
 """
 KV cache helper for store.
 """
+from collections import defaultdict
+from collections.abc import Sequence
+from concurrent.futures import CancelledError, Future
+from typing import Optional, cast
+
 import torch

 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.v1.outputs import ModelRunnerOutput

 logger = init_logger(__name__)

@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
                     "layout to HND for better xfer performance.")
         return "HND"
     return "NHD"
+
+
+class KVOutputAggregator:
+    """Utility class to aggregate the output of all workers into a single
+    output corresponding to rank 0 for the scheduler."""
+
+    def __init__(self, world_size: int):
+        # Complete transfer tracker. Used to track finished requests
+        # [req_id -> n_finished_workers]
+        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
+        self._send_remaining_count = defaultdict[str, int](lambda: world_size)
+
+    def aggregate(self,
+                  outputs: list[ModelRunnerOutput],
+                  output_rank: int = 0) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        def update_finished_set(req_ids: Optional[set[str]],
+                                remaining_count_dict: dict[str, int],
+                                finished_set: set[str]) -> None:
+            for req_id in req_ids or ():
+                new_count = remaining_count_dict[req_id] - 1
+                if new_count == 0:
+                    finished_set.add(req_id)
+                    del remaining_count_dict[req_id]
+                else:
+                    remaining_count_dict[req_id] = new_count
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            update_finished_set(output.finished_sending,
+                                self._send_remaining_count, finished_sending)
+            update_finished_set(output.finished_recving,
+                                self._recv_remaining_count, finished_recving)
+
+        # select output of the worker specified by output_rank
+        output = outputs[output_rank]
+
+        # set the aggregated finished_sending / finished_recving:
+        # even if this rank's output has finished_sending/recving set, report
+        # None until every rank has finished the corresponding send/recv
+        output.finished_sending = finished_sending if finished_sending else None
+        output.finished_recving = finished_recving if finished_recving else None
+
+        return output
+
+    def async_aggregate(self,
+                        output_futures: Sequence[Future[ModelRunnerOutput]],
+                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
+        """Takes a list of futures and returns a single future which resolves
+        to the aggregated output once all the input futures complete."""
+        result_future: Future[ModelRunnerOutput] = Future()

+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self.aggregate(cast(list[ModelRunnerOutput], outputs),
+                                       output_rank))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
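A minimal usage sketch of the new aggregator, mirroring the renamed unit test above (DummyModelRunnerOutput stands in for a real ModelRunnerOutput):

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

aggregator = KVOutputAggregator(world_size=2)

# Rank 0 reports 'req1' as sent, rank 1 has not yet: nothing is surfaced.
out = aggregator.aggregate([
    DummyModelRunnerOutput(finished_sending={'req1'}, finished_recving=None),
    DummyModelRunnerOutput(finished_sending=None, finished_recving=None),
])
assert out.finished_sending is None

# Once rank 1 also reports 'req1', the aggregate surfaces it to the scheduler.
out = aggregator.aggregate([
    DummyModelRunnerOutput(finished_sending=None, finished_recving=None),
    DummyModelRunnerOutput(finished_sending={'req1'}, finished_recving=None),
])
assert out.finished_sending == {'req1'}
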
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 9459ab27aba3..e1245775bea3 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -194,7 +194,7 @@ class KVConnectorBase_V1(ABC):
         """
         Notifies worker-side connector ids of requests that have
         finished generating tokens on the worker.
-        The scheduler process (via the MultiprocExecutor) will use this output
+        The scheduler process (via the Executors) will use this output
         to track which workers are done.

         Returns:
diff --git a/vllm/mocks/__init__.py b/vllm/mocks/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/mocks/mock_nixl_connector.py b/vllm/mocks/mock_nixl_connector.py
new file mode 100644
index 000000000000..54e2c5ee3b0a
--- /dev/null
+++ b/vllm/mocks/mock_nixl_connector.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import uuid
+from collections import defaultdict
+from typing import Optional
+
+
+class FakeNixlWrapper:
+    """Mock implementation of NixlWrapper for testing.
+
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
+    """
+
+    AGENT_METADATA = b"fake_agent_metadata"
+    REMOTE_AGENT_NAME = "remote_agent"
+
+    def __init__(self, agent_name: str, *args, **kwargs):
+        self._cycles_before_xfer_done = 0
+        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
+            lambda: 0)
+
+    def get_reg_descs(self, caches_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in caches_data]
+
+    def register_memory(self, descs) -> None:
+        pass
+
+    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in blocks_data]
+
+    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
+        return uuid.uuid4().int
+
+    def get_agent_metadata(self) -> bytes:
+        return self.AGENT_METADATA
+
+    def add_remote_agent(self, agent_metadata: bytes) -> str:
+        return self.REMOTE_AGENT_NAME
+
+    def get_new_notifs(self) -> dict[str, list[bytes]]:
+        # Used to collect done_sending, which we don't test yet.
+        return {}
+
+    def check_xfer_state(self, handle: int) -> str:
+        if self._check_xfer_state_cycles[
+                handle] >= self._cycles_before_xfer_done:
+            return "DONE"
+        self._check_xfer_state_cycles[handle] += 1
+        return "PROC"
+
+    def release_xfer_handle(self, handle: int) -> None:
+        pass
+
+    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
+        pass
+
+    def make_prepped_xfer(self,
+                          xfer_type: str,
+                          local_xfer_side_handle: int,
+                          local_block_descs_ids: list[int],
+                          remote_xfer_side_handle: int,
+                          remote_block_descs_ids: list[int],
+                          notif_msg: Optional[bytes] = None) -> int:
+        return uuid.uuid4().int
+
+    def transfer(self, handle: int) -> str:
+        return "PROC"
+
+    ############################################################
+    # The following methods are for changing behavior during testing.
+    ############################################################
+
+    def set_cycles_before_xfer_done(self, cycles: int):
+        """Set the number of cycles before a transfer is considered done."""
+        self._cycles_before_xfer_done = cycles
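The fake is wired in by patching the symbol the connector resolves at runtime, exactly as the tests above do; a minimal sketch:

from unittest.mock import patch

from vllm.mocks.mock_nixl_connector import FakeNixlWrapper

with patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper):
    ...  # construct the NixlConnector / run the worker under test here
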
""" tensors: dict[str, torch.Tensor] + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None def __init__(self, tensors): # manually define this function, so that diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 4a4144c4860a..11ddade3eb70 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -9,8 +9,7 @@ import threading import time import traceback import weakref -from collections import defaultdict -from concurrent.futures import CancelledError, Future, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto from functools import partial @@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger @@ -118,13 +118,8 @@ class MultiprocExecutor(Executor): self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None - - # Complete transfer tracker. Used by to track finished requests - # [req_id -> n_finished_workers] - self._recv_remaining_count = defaultdict[str, - int](lambda: self.world_size) - self._send_remaining_count = defaultdict[str, - int](lambda: self.world_size) + self.kv_output_aggregator = KVOutputAggregator( + self.parallel_config.world_size) def start_worker_monitor(self): workers = self.workers @@ -186,8 +181,9 @@ class MultiprocExecutor(Executor): # aggregate all workers output to a single output if non_block: - return self._async_aggregate_workers_output(outputs) - return self._aggregate_workers_output(outputs) + return self.kv_output_aggregator.async_aggregate( + outputs, self.output_rank) + return self.kv_output_aggregator.aggregate(outputs, self.output_rank) def collective_rpc(self, method: Union[str, Callable], @@ -246,74 +242,6 @@ class MultiprocExecutor(Executor): except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e - def _aggregate_workers_output( - self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: - # aggregate finished_sending, finished_recving from all workers - - def update_finished_set(req_ids: Optional[set[str]], - remaining_count_dict: dict[str, int], - finished_set: set[str]) -> None: - for req_id in req_ids or (): - new_count = remaining_count_dict[req_id] - 1 - if new_count == 0: - finished_set.add(req_id) - del remaining_count_dict[req_id] - else: - remaining_count_dict[req_id] = new_count - - finished_sending = set[str]() - finished_recving = set[str]() - for output in outputs: - update_finished_set(output.finished_sending, - self._send_remaining_count, finished_sending) - update_finished_set(output.finished_recving, - self._recv_remaining_count, finished_recving) - - # select output of the worker specified by output_rank - output = outputs[self.output_rank] - - # set the aggregated finished_sending / finished_recving - output.finished_sending = finished_sending if finished_sending else None - output.finished_recving = finished_recving if finished_recving else None - - return output - - def _async_aggregate_workers_output( - self, output_futures: list[Future[ModelRunnerOutput]] - ) -> 
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 4a4144c4860a..11ddade3eb70 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -9,8 +9,7 @@ import threading
 import time
 import traceback
 import weakref
-from collections import defaultdict
-from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                   MessageQueue)
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.multiproc_worker_utils import (
     _add_prefix, set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
@@ -118,13 +118,8 @@ class MultiprocExecutor(Executor):
             self.output_rank = self._get_output_rank()

         self.has_connector = self.vllm_config.kv_transfer_config is not None
-
-        # Complete transfer tracker. Used by to track finished requests
-        # [req_id -> n_finished_workers]
-        self._recv_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self._send_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)

     def start_worker_monitor(self):
         workers = self.workers
@@ -186,8 +181,9 @@ class MultiprocExecutor(Executor):

         # aggregate all workers output to a single output
         if non_block:
-            return self._async_aggregate_workers_output(outputs)
-        return self._aggregate_workers_output(outputs)
+            return self.kv_output_aggregator.async_aggregate(
+                outputs, self.output_rank)
+        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

     def collective_rpc(self,
                        method: Union[str, Callable],
@@ -246,74 +242,6 @@ class MultiprocExecutor(Executor):
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e

-    def _aggregate_workers_output(
-            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
-
-        def update_finished_set(req_ids: Optional[set[str]],
-                                remaining_count_dict: dict[str, int],
-                                finished_set: set[str]) -> None:
-            for req_id in req_ids or ():
-                new_count = remaining_count_dict[req_id] - 1
-                if new_count == 0:
-                    finished_set.add(req_id)
-                    del remaining_count_dict[req_id]
-                else:
-                    remaining_count_dict[req_id] = new_count
-
-        finished_sending = set[str]()
-        finished_recving = set[str]()
-        for output in outputs:
-            update_finished_set(output.finished_sending,
-                                self._send_remaining_count, finished_sending)
-            update_finished_set(output.finished_recving,
-                                self._recv_remaining_count, finished_recving)
-
-        # select output of the worker specified by output_rank
-        output = outputs[self.output_rank]
-
-        # set the aggregated finished_sending / finished_recving
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
-
-        return output
-
-    def _async_aggregate_workers_output(
-            self, output_futures: list[Future[ModelRunnerOutput]]
-    ) -> (Future[ModelRunnerOutput]):
-        """Takes a list of futures and returns a single future which resolves
-        to the respective list of outputs."""
-        result_future: Future[ModelRunnerOutput] = Future()
-
-        outputs: list[Optional[ModelRunnerOutput]] = [None
-                                                      ] * len(output_futures)
-
-        def make_callback(idx):
-
-            def callback(fut):
-                if result_future.done():
-                    return
-
-                try:
-                    outputs[idx] = fut.result()
-                except CancelledError:
-                    result_future.cancel()
-                except Exception as e:
-                    result_future.set_exception(e)
-
-                # this check assumes io_thread_pool uses a single thread
-                if all(outputs):
-                    result_future.set_result(
-                        self._aggregate_workers_output(
-                            cast(list[ModelRunnerOutput], outputs)))
-
-            return callback
-
-        for i, output_future in enumerate(output_futures):
-            output_future.add_done_callback(make_callback(i))
-
-        return result_future
-
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have
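With the tracker factored out, the non-blocking path simply hands the per-worker futures to the aggregator; the returned future resolves only after every worker's future completes. A condensed sketch, with DummyModelRunnerOutput as in the unit test above:

from concurrent.futures import Future

aggregator = KVOutputAggregator(world_size=2)
future1: Future = Future()
future2: Future = Future()
result_future = aggregator.async_aggregate([future1, future2])

future1.set_result(
    DummyModelRunnerOutput(finished_sending=None, finished_recving=None))
assert not result_future.done()   # still waiting on the second worker

future2.set_result(
    DummyModelRunnerOutput(finished_sending=None, finished_recving=None))
assert result_future.done()       # resolves once all workers have reported
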
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
index eb659e4f9e47..b86ac048f520 100644
--- a/vllm/v1/executor/ray_distributed_executor.py
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -2,33 +2,55 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from concurrent.futures import Future
-from typing import Union
+from typing import Optional, Union

+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.ray_distributed_executor import (  # noqa
     RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.logger import init_logger
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput

+logger = init_logger(__name__)
+

 class FutureWrapper(Future):
-    """A wrapper around a Ray output reference to meet the interface
-    of .execute_model().
+    """A wrapper around a list of Ray output references that meets the
+    interface of .execute_model(): the top level (core busy loop) expects
+    .result() to block and return a single output.
+
+    If an aggregator is provided, the outputs from all workers are aggregated
+    upon the result() call. If not, only the first worker's output is returned.
     """

-    def __init__(self, ref):
+    def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
         super().__init__()
-        self.ref = ref
+        self.refs = refs
+        self.aggregator = aggregator

     def result(self, timeout=None):
         if timeout is not None:
             raise NotImplementedError("timeout is not supported")
-        return self.ref.get()
+
+        if self.aggregator is None:
+            return self.refs[0].get()
+
+        outputs = [ref.get() for ref in self.refs]
+        return self.aggregator.aggregate(outputs, output_rank=0)


 class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
     """Ray distributed executor using Ray Compiled Graphs."""

+    def _init_executor(self) -> None:
+        super()._init_executor()
+
+        # KV connector setup
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)
+
     @property
     def max_concurrent_batches(self) -> int:
         """Ray distributed executor supports pipeline parallelism,
@@ -56,13 +78,24 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):

         refs = self.forward_dag.execute(scheduler_output)  # type: ignore

-        # When PP is not used, we block here until the result is available.
-        if self.max_concurrent_batches == 1:
-            return refs[0].get()
+        if not self.has_connector:
+            # Get output only from a single worker (output_rank)
+            # When PP is not used, we block here until the result is available.
+            if self.max_concurrent_batches == 1:
+                return refs[0].get()

-        # When PP is used, we return a FutureWrapper immediately so that
-        # the scheduler can yield to the next batch.
-        return FutureWrapper(refs[0])
+            # When PP is used, we return a FutureWrapper immediately so that
+            # the scheduler can yield to the next batch.
+            return FutureWrapper(refs)
+
+        # Get output from all workers when a connector is present
+        if self.max_concurrent_batches == 1:
+            # Block and get results from all workers
+            outputs = [ref.get() for ref in refs]
+            return self.kv_output_aggregator.aggregate(outputs)
+
+        # Return a future that will aggregate outputs from all workers
+        return FutureWrapper(refs, self.kv_output_aggregator)

     def reinitialize_distributed(
             self, reconfig_request: ReconfigureDistributedRequest) -> None:
@@ -70,4 +103,4 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         if reconfig_request.new_data_parallel_rank == \
             ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
             self.shutdown()
-            return
+            return
\ No newline at end of file
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a5c446731144..d5449a68bc28 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import copy
 import gc
 import time
 from contextlib import contextmanager
@@ -1270,6 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
+        finished_sending: Optional[set[str]],
+        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1304,6 +1307,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
         )

     @torch.inference_mode()
@@ -1314,12 +1319,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if has_kv_transfer_group():
-                with set_forward_context(None, self.vllm_config):
-                    self.maybe_setup_kv_connector(scheduler_output)
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT

-            # Return empty ModelRunnerOutput if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            return self.kv_connector_no_forward(scheduler_output)

         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1412,6 +1416,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )

         self.maybe_wait_for_kv_save()
+        finished_sending, finished_recving = (
+            self.get_finished_kv_transfers(scheduler_output))

         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
@@ -1429,6 +1435,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
+                if finished_sending or finished_recving:
+                    hidden_states.finished_sending = finished_sending
+                    hidden_states.finished_recving = finished_recving
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(hidden_states.tensors,
@@ -1437,7 +1446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
+                                  num_scheduled_tokens_np, finished_sending,
+                                  finished_recving)

             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1587,6 +1597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
         )

@@ -1711,6 +1723,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().wait_for_save()

+    @staticmethod
+    def get_finished_kv_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # KV send/recv even if no work to do.
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
+
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
+
     def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 2201481fa5bf..6411874883ef 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -15,9 +15,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          get_kv_transfer_group,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -335,25 +333,17 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            output = EMPTY_MODEL_RUNNER_OUTPUT
+
+            # In case of PP with kv transfer, we need to pass through the
+            # finished_sending and finished_recving buffers.
+            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
+            if output.finished_sending or output.finished_recving:
+                empty_output = copy.copy(empty_output)
+                empty_output.finished_sending = output.finished_sending
+                empty_output.finished_recving = output.finished_recving
+            output = empty_output

         assert isinstance(output, ModelRunnerOutput)
-        if has_kv_transfer_group():
-            finished_sending, finished_recving = (
-                get_kv_transfer_group().get_finished(
-                    scheduler_output.finished_req_ids))
-            if finished_sending or finished_recving:
-                if output is EMPTY_MODEL_RUNNER_OUTPUT:
-                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-                output.finished_sending = finished_sending
-                output.finished_recving = finished_recving
-
-            # Clear KVConnector state for this step.
-            get_kv_transfer_group().clear_connector_metadata()
-
-            # with a connector, the scheduler expects output from all workers
-            return output
-
         return output if self.is_driver_worker else None