diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 99bde919c725..c5ca7df83685 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import inspect import os import tempfile import textwrap import time +import uuid +from collections import defaultdict +from typing import Optional from unittest.mock import patch import pytest @@ -16,30 +21,118 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) from vllm.forward_context import ForwardContext -from vllm.mocks.mock_nixl_connector import FakeNixlWrapper from vllm.sampling_params import SamplingParams from .utils import create_request, create_scheduler, create_vllm_config -def _make_stub_pkg() -> str: - """Return a directory that makes - `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.""" - td = tempfile.mkdtemp() - pkg_root = os.path.join(td, "nixl", "_api") - os.makedirs(pkg_root, exist_ok=True) +class FakeNixlWrapper: + """Mock implementation of NixlWrapper for testing. - stub = textwrap.dedent("""\ - # Forward the real FakeNixlWrapper that the driver already defined. - print("In fake package") - from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent - """) - with open(os.path.join(pkg_root, "__init__.py"), "w") as f: - f.write(stub) + We don't inherit from nixl._api.nixl_agent because nixl may not be + installed. + + Note: The complete source of this class is also used in the + `_make_fake_nixl_pkg` function to create a fake nixl package + for Ray workers. + """ - # touch parent package - open(os.path.join(td, "nixl", "__init__.py"), "w").close() - return td + AGENT_METADATA = b"fake_agent_metadata" + REMOTE_AGENT_NAME = "remote_agent" + + def __init__(self, agent_name: str, *args, **kwargs): + self._cycles_before_xfer_done = 0 + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( + lambda: 0) + + def get_reg_descs(self, caches_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in caches_data] + + def register_memory(self, descs) -> None: + pass + + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in blocks_data] + + def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: + return uuid.uuid4().int + + def get_agent_metadata(self) -> bytes: + return self.AGENT_METADATA + + def add_remote_agent(self, agent_metadata: bytes) -> str: + return self.REMOTE_AGENT_NAME + + def get_new_notifs(self) -> dict[str, list[bytes]]: + # Used to collect done_sending, which we don't test yet. + return {} + + def check_xfer_state(self, handle: int) -> str: + if self._check_xfer_state_cycles[ + handle] >= self._cycles_before_xfer_done: + return "DONE" + self._check_xfer_state_cycles[handle] += 1 + return "PROC" + + def release_xfer_handle(self, handle: int) -> None: + pass + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + pass + + def make_prepped_xfer(self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None) -> int: + return uuid.uuid4().int + + def transfer(self, handle: int) -> str: + return "PROC" + + ############################################################ + # Follow are for changing the behavior during testing. + ############################################################ + + def set_cycles_before_xfer_done(self, cycles: int): + """Set the number of cycles before a transfer is considered done.""" + self._cycles_before_xfer_done = cycles + + +@contextlib.contextmanager +def _make_fake_nixl_pkg(): + """Context manager that creates a temporary package making + `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. + + Automatically cleans up the temporary directory when done. + """ + with tempfile.TemporaryDirectory() as td: + pkg_root = os.path.join(td, "nixl", "_api") + os.makedirs(pkg_root, exist_ok=True) + + # Get the source code of FakeNixlWrapper class and dedent it + fake_nixl_source = inspect.getsource(FakeNixlWrapper) + fake_nixl_source = textwrap.dedent(fake_nixl_source) + + stub = f"""\ +# Copy of FakeNixlWrapper implementation for Ray workers +import uuid +from collections import defaultdict +from typing import Optional + +{fake_nixl_source} + +# Export as nixl_agent +nixl_agent = FakeNixlWrapper +""" + with open(os.path.join(pkg_root, "__init__.py"), "w") as f: + f.write(stub) + + # touch parent package + open(os.path.join(td, "nixl", "__init__.py"), "w").close() + yield td def test_basic_interface(): @@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): kv_connector="NixlConnector", kv_role="kv_both", ) + llm_kwargs = { + "model": model_name, + "enforce_eager": True, + "gpu_memory_utilization": 0.5, + "kv_transfer_config": kv_transfer_config, + "distributed_executor_backend": distributed_executor_backend, + } + timeout = 6 monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) - # Build runtime_env only if we’re using Ray + # Build runtime_env only if we're using Ray if distributed_executor_backend == "ray": - runtime_env = { - "working_dir": _make_stub_pkg(), # ship stub package - "env_vars": { - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), - }, - } - ray.init(runtime_env=runtime_env) + with _make_fake_nixl_pkg() as working_dir: + runtime_env = { + "working_dir": working_dir, # ship fake nixl package + "env_vars": { + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + }, + } + ray.init(runtime_env=runtime_env) - llm = LLM( - model=model_name, - enforce_eager=True, - gpu_memory_utilization=0.5, - kv_transfer_config=kv_transfer_config, - distributed_executor_backend=distributed_executor_backend, - ) + _run_abort_timeout_test(llm_kwargs, timeout) + else: + _run_abort_timeout_test(llm_kwargs, timeout) + + +def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): + """Helper function to run the abort timeout test logic.""" + llm = LLM(**llm_kwargs) remote_prefill_opts = { "do_remote_decode": True, "do_remote_prefill": False, diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index c179d6cc29b7..459a53298914 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -120,8 +120,8 @@ class KVOutputAggregator: output corresponding to Rank 0 for scheduler.""" def __init__(self, world_size: int): - # Complete transfer tracker. Used by to track finished requests - # [req_id -> n_finished_workers] + # Complete transfer tracker. Used to track finished requests + # [req_id -> n_remaining_workers] self._recv_remaining_count = defaultdict[str, int](lambda: world_size) self._send_remaining_count = defaultdict[str, int](lambda: world_size) @@ -134,12 +134,10 @@ class KVOutputAggregator: remaining_count_dict: dict[str, int], finished_set: set[str]) -> None: for req_id in req_ids or (): - new_count = remaining_count_dict[req_id] - 1 - if new_count == 0: + remaining_count_dict[req_id] -= 1 + if remaining_count_dict[req_id] == 0: finished_set.add(req_id) del remaining_count_dict[req_id] - else: - remaining_count_dict[req_id] = new_count finished_sending = set[str]() finished_recving = set[str]() diff --git a/vllm/mocks/__init__.py b/vllm/mocks/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/mocks/mock_nixl_connector.py b/vllm/mocks/mock_nixl_connector.py deleted file mode 100644 index 54e2c5ee3b0a..000000000000 --- a/vllm/mocks/mock_nixl_connector.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import uuid -from collections import defaultdict -from typing import Optional - - -class FakeNixlWrapper: - """Mock implementation of NixlWrapper for testing. - - We don't inherit from nixl._api.nixl_agent because nixl may not be - installed. - """ - - AGENT_METADATA = b"fake_agent_metadata" - REMOTE_AGENT_NAME = "remote_agent" - - def __init__(self, agent_name: str, *args, **kwargs): - self._cycles_before_xfer_done = 0 - self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( - lambda: 0) - - def get_reg_descs(self, caches_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in caches_data] - - def register_memory(self, descs) -> None: - pass - - def get_xfer_descs(self, blocks_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in blocks_data] - - def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: - return uuid.uuid4().int - - def get_agent_metadata(self) -> bytes: - return self.AGENT_METADATA - - def add_remote_agent(self, agent_metadata: bytes) -> str: - return self.REMOTE_AGENT_NAME - - def get_new_notifs(self) -> dict[str, list[bytes]]: - # Used to collect done_sending, which we don't test yet. - return {} - - def check_xfer_state(self, handle: int) -> str: - if self._check_xfer_state_cycles[ - handle] >= self._cycles_before_xfer_done: - return "DONE" - self._check_xfer_state_cycles[handle] += 1 - return "PROC" - - def release_xfer_handle(self, handle: int) -> None: - pass - - def send_notif(self, agent_name: str, notif_msg: bytes) -> None: - pass - - def make_prepped_xfer(self, - xfer_type: str, - local_xfer_side_handle: int, - local_block_descs_ids: list[int], - remote_xfer_side_handle: int, - remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None) -> int: - return uuid.uuid4().int - - def transfer(self, handle: int) -> str: - return "PROC" - - ############################################################ - # Follow are for changing the behavior during testing. - ############################################################ - - def set_cycles_before_xfer_done(self, cycles: int): - """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles