From f05a4f0e345bbfd4a7cb3f421bd9412e1cc53e74 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Tue, 23 Sep 2025 22:08:02 +0800 Subject: [PATCH] [P/D] Support NIXL connector to disconnect during a clean shutdown (#24423) Signed-off-by: chaunceyjiang Co-authored-by: Mark McLoughlin --- .../kv_connector/unit/test_nixl_connector.py | 52 +++++++++++++++++++ .../kv_connector/v1/nixl_connector.py | 35 ++++++++++--- 2 files changed, 80 insertions(+), 7 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index fa698a2eabd9..24cc83c28614 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -60,6 +60,9 @@ class FakeNixlWrapper: def register_memory(self, descs, backends) -> None: pass + def deregister_memory(self, descs) -> None: + pass + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: return [str(uuid.uuid4()) for _ in blocks_data] @@ -86,6 +89,12 @@ class FakeNixlWrapper: def release_xfer_handle(self, handle: int) -> None: pass + def release_dlist_handle(self, handle: int) -> None: + pass + + def remove_remote_agent(self, agent: str) -> None: + pass + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: pass @@ -905,3 +914,46 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, # Verify get_reg_descs was called with the correct memory_type assert connector.connector_worker.kv_buffer_device == kv_buffer_device assert connector.connector_worker.nixl_memory_type == nixl_memory_type + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) +def test_shutdown_cleans_up_resources(dist_init): + """Test that shutdown() properly cleans up all resources.""" + vllm_config = create_vllm_config() + + worker = NixlConnectorWorker(vllm_config, + vllm_config.kv_transfer_config.engine_id) + nixl_wrapper = worker.nixl_wrapper + + with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \ + patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \ + patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \ + patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \ + patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \ + patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg: + + worker._recving_transfers = {"req1": [(123, time.perf_counter())]} + worker.src_xfer_side_handle = 456 + worker.dst_xfer_side_handles = {"engine1": 789} + worker._remote_agents = {"engine1": {0: "agent1"}} + worker._registered_descs = ["desc1", "desc2"] + + worker.shutdown() + + # Test idempotency + worker.shutdown() + worker.shutdown() + + mock_exec.shutdown.assert_called_with(wait=False) + mock_listener.join.assert_called_once_with(timeout=0) + + mock_rel_xfer.assert_called_once_with(123) + assert mock_rel_dlist.call_count == 2 + mock_rel_dlist.assert_any_call(456) # src handle + mock_rel_dlist.assert_any_call(789) # dst handle + mock_rem_agent.assert_called_once_with("agent1") + assert mock_dereg.call_count == 2 + mock_dereg.assert_any_call("desc1") + mock_dereg.assert_any_call("desc2") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 82b483447e33..64feddb591c2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -250,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1): self.connector_worker.copy_blocks: self.connector_worker.save_kv_to_host(self._connector_metadata) + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + class NixlConnectorScheduler: """Implementation of Scheduler side methods""" @@ -586,13 +590,6 @@ class NixlConnectorWorker: self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - def __del__(self): - """Cleanup background threads on destruction.""" - if executor := getattr(self, "_handshake_initiation_executor", None): - executor.shutdown(wait=False) - if listener_t := getattr(self, "_nixl_handshake_listener_t", None): - listener_t.join(timeout=0) - @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, ready_event: threading.Event, base_port: int, @@ -1346,6 +1343,30 @@ class NixlConnectorWorker: return self.xfer_stats.clone_and_reset() return None + def shutdown(self): + """Shutdown the connector worker.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t is not None: + self._nixl_handshake_listener_t.join(timeout=0) + self._nixl_handshake_listener_t = None + for handles in self._recving_transfers.values(): + for handle, _ in handles: + self.nixl_wrapper.release_xfer_handle(handle) + self._recving_transfers.clear() + if self.src_xfer_side_handle: + self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) + self.src_xfer_side_handle = 0 + for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + self.dst_xfer_side_handles.clear() + for remote_agents in self._remote_agents.values(): + for agent_name in remote_agents.values(): + self.nixl_wrapper.remove_remote_agent(agent_name) + self._remote_agents.clear() + for desc in self._registered_descs: + self.nixl_wrapper.deregister_memory(desc) + self._registered_descs.clear() + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: