From 1bcd15edc71422e4eb4525f5e07903d73187da17 Mon Sep 17 00:00:00 2001 From: lkchen Date: Sun, 22 Jun 2025 22:41:53 -0700 Subject: [PATCH] [BugFix][P/D] Fix for cases where _recving_transfers can be cleaned up when *all* transfer done (#19874) Signed-off-by: Linkun Chen --- .../kv_connector/unit/test_nixl_connector.py | 174 +++++++++++++++++- .../kv_connector/v1/nixl_connector.py | 9 +- 2 files changed, 179 insertions(+), 4 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a0bcb8f602e11..b00be7b83e12b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1,8 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +import uuid +from collections import defaultdict +from typing import Optional +from unittest.mock import patch + +import pytest + +try: + from nixl._api import nixl_agent as NixlWrapper +except ImportError: + NixlWrapper = None + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( - NixlConnectorMetadata) + KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, + NixlConnectorWorker) +from vllm.forward_context import ForwardContext from .utils import create_request, create_scheduler, create_vllm_config @@ -72,3 +87,160 @@ def test_prompt_less_than_block_size(): # This request should be scheduled regularly. assert len(scheduler_output.scheduled_new_reqs) == 1 + + +class FakeNixlWrapper: + """Mock implementation of NixlWrapper for testing. + + We don't inherit from NixlWrapper because NixlWrapper could be None. + """ + + AGENT_METADATA = b"fake_agent_metadata" + REMOTE_AGENT_NAME = "remote_agent" + + def __init__(self, agent_name: str, *args, **kwargs): + self._cycles_before_xfer_done = 0 + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( + lambda: 0) + + def get_reg_descs(self, caches_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in caches_data] + + def register_memory(self, descs) -> None: + pass + + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in blocks_data] + + def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: + return uuid.uuid4().int + + def get_agent_metadata(self) -> bytes: + return self.AGENT_METADATA + + def add_remote_agent(self, agent_metadata: bytes) -> str: + return self.REMOTE_AGENT_NAME + + def get_new_notifs(self) -> dict[str, list[bytes]]: + # Used to collect done_sending, which we don't test yet. + return {} + + def check_xfer_state(self, handle: int) -> str: + if self._check_xfer_state_cycles[ + handle] >= self._cycles_before_xfer_done: + return "DONE" + self._check_xfer_state_cycles[handle] += 1 + return "PROC" + + def release_xfer_handle(self, handle: int) -> None: + pass + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + pass + + def make_prepped_xfer(self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None) -> int: + return uuid.uuid4().int + + def transfer(self, handle: int) -> str: + return "PROC" + + ############################################################ + # Follow are for changing the behavior during testing. + ############################################################ + + def set_cycles_before_xfer_done(self, cycles: int): + """Set the number of cycles before a transfer is considered done.""" + self._cycles_before_xfer_done = cycles + + +class FakeNixlConnectorWorker(NixlConnectorWorker): + + REMOTE_ENGINE_ID = "remote_engine" + + def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): + super().__init__(*args, **kwargs) + self._hand_shake_latency = hand_shake_latency + + def _nixl_handshake(self, host: str, port: int): + # Mimic slow _nixl_handshake, as well as bypass zmq communication. + time.sleep(self._hand_shake_latency) + # These should've been done in register_kv_caches(), called by + # gpu_model_runner. Here we just hardcode some dummy values. + self.slot_size_bytes = 4096 + self.block_len = self.slot_size_bytes * self.block_size + self.num_blocks = 1 + self.dst_num_blocks[self.engine_id] = self.num_blocks + + self.add_remote_agent( + NixlAgentMetadata( + engine_id=self.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + num_blocks=1, + tp_size=1, + block_len=self.block_len, + attn_backend_name=self.backend_name, + )) + + +@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed") +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) +def test_multi_xfer_one_engine( + # dist_init is a fixture that initializes the distributed environment. + dist_init): + """Test case where multiple xfers are initiated to the same engine. + + This test triggers the connector to load remote KV for the same + `request_id`. The transfer is not done immediately due to + `set_cycles_before_xfer_done`, so there is a state where there are multiple + transfer states for the same `request_id`, and `get_finished` should handle + it correctly (wait for all transfers to be done). + """ + vllm_config = create_vllm_config() + + request_id = "req_id" + + # Test worker role in decode server. + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker(vllm_config, + connector.engine_id, + hand_shake_latency=0) + assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) + connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) + for i in range(4): + metadata = NixlConnectorMetadata() + metadata.add_new_req(request_id=request_id, + local_block_ids=[i + 1, i + 2, i + 3], + kv_transfer_params={ + "remote_block_ids": [i + 4, i + 5, i + 6], + "remote_engine_id": + FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + }) + connector.bind_connector_metadata(metadata) + + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + _before_load = time.perf_counter() + connector.start_load_kv(dummy_ctx) + _after_load = time.perf_counter() + assert _after_load - _before_load < 0.1, "start_load_kv took " \ + f"{_after_load - _before_load} seconds" + + while True: + _, done_recving = connector.get_finished(finished_req_ids=set()) + if len(done_recving) > 0: + assert request_id in done_recving + break diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 94f757e007af7..2d80cbf2b24f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -841,17 +841,20 @@ class NixlConnectorWorker: """ done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): - for handle, xfer_stime in handles: + in_progress = False + for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": self.nixl_wrapper.release_xfer_handle(handle) - done_req_ids.add(req_id) - del transfers[req_id] elif xfer_state == "PROC": + in_progress = True continue else: raise RuntimeError("Transfer failed with state %s", xfer_state) + if not in_progress: + done_req_ids.add(req_id) + del transfers[req_id] return done_req_ids def start_load_kv(self, metadata: NixlConnectorMetadata):