[P/D] Support NIXL connector to disconnect during a clean shutdown (#24423)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Chauncey 2025-09-23 22:08:02 +08:00 committed by GitHub
parent 61d1b35561
commit f05a4f0e34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 7 deletions

View File

@ -60,6 +60,9 @@ class FakeNixlWrapper:
def register_memory(self, descs, backends) -> None:
    """Accept a memory-registration request and do nothing.

    Both arguments are ignored; the fake wrapper keeps no record of them.
    """
def deregister_memory(self, descs) -> None:
    """Accept a memory-deregistration request and do nothing.

    ``descs`` is ignored; the method exists so callers can invoke it on the
    fake wrapper without error.
    """
    return None
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
    """Return one freshly generated UUID string per entry in ``blocks_data``.

    ``memory_type`` is not used. The returned strings are random placeholders,
    so two calls with the same input produce different values.
    """
    return [str(uuid.uuid4()) for _ in blocks_data]
@ -86,6 +89,12 @@ class FakeNixlWrapper:
def release_xfer_handle(self, handle: int) -> None:
    """Accept a transfer handle to release and do nothing with it."""
def release_dlist_handle(self, handle: int) -> None:
    """Accept a descriptor-list handle to release and do nothing with it."""
    return None
def remove_remote_agent(self, agent: str) -> None:
    """Accept the name of a remote agent to remove and do nothing with it."""
    return None
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
    """Accept a notification for ``agent_name`` and discard it.

    Neither argument is stored or forwarded.
    """
@ -905,3 +914,46 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
# Verify get_reg_descs was called with the correct memory_type # Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device assert connector.connector_worker.kv_buffer_device == kv_buffer_device
assert connector.connector_worker.nixl_memory_type == nixl_memory_type assert connector.connector_worker.nixl_memory_type == nixl_memory_type
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_shutdown_cleans_up_resources(dist_init):
    """Check that worker.shutdown() releases every tracked resource once.

    The worker is seeded with one pending transfer handle, a source and a
    destination descriptor-list handle, one remote agent and two registered
    descriptors. shutdown() is then invoked three times; the release calls
    must have happened exactly once per resource, which shows that repeated
    calls do not release anything twice.
    """
    config = create_vllm_config()
    worker = NixlConnectorWorker(config, config.kv_transfer_config.engine_id)
    wrapper = worker.nixl_wrapper

    with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
         patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
         patch.object(wrapper, 'release_xfer_handle') as mock_rel_xfer, \
         patch.object(wrapper, 'release_dlist_handle') as mock_rel_dlist, \
         patch.object(wrapper, 'remove_remote_agent') as mock_rem_agent, \
         patch.object(wrapper, 'deregister_memory') as mock_dereg:
        # Seed the worker with one resource of each kind shutdown() handles.
        worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
        worker.src_xfer_side_handle = 456
        worker.dst_xfer_side_handles = {"engine1": 789}
        worker._remote_agents = {"engine1": {0: "agent1"}}
        worker._registered_descs = ["desc1", "desc2"]

        # First call does the cleanup; the next two exercise idempotency.
        for _ in range(3):
            worker.shutdown()

        mock_exec.shutdown.assert_called_with(wait=False)
        mock_listener.join.assert_called_once_with(timeout=0)
        mock_rel_xfer.assert_called_once_with(123)

        # One release for the src handle (456), one for the dst handle (789).
        assert mock_rel_dlist.call_count == 2
        for dlist_handle in (456, 789):
            mock_rel_dlist.assert_any_call(dlist_handle)

        mock_rem_agent.assert_called_once_with("agent1")

        assert mock_dereg.call_count == 2
        for registered_desc in ("desc1", "desc2"):
            mock_dereg.assert_any_call(registered_desc)

View File

@ -250,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1):
self.connector_worker.copy_blocks: self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata) self.connector_worker.save_kv_to_host(self._connector_metadata)
def shutdown(self):
    """Forward the shutdown request to the connector worker, if one is set.

    Does nothing when ``self.connector_worker`` is None.
    """
    worker = self.connector_worker
    if worker is None:
        return
    worker.shutdown()
class NixlConnectorScheduler: class NixlConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
@ -586,13 +590,6 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
def __del__(self):
    """Best-effort cleanup of the background executor and listener thread.

    Attributes are looked up with ``getattr`` and a ``None`` default, so the
    method does not raise if either attribute was never set on the instance.
    """
    executor = getattr(self, "_handshake_initiation_executor", None)
    if executor:
        # Do not block while the executor shuts down.
        executor.shutdown(wait=False)

    listener_t = getattr(self, "_nixl_handshake_listener_t", None)
    if listener_t:
        # timeout=0: a non-blocking join attempt.
        listener_t.join(timeout=0)
@staticmethod @staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata, def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, base_port: int, ready_event: threading.Event, base_port: int,
@ -1346,6 +1343,30 @@ class NixlConnectorWorker:
return self.xfer_stats.clone_and_reset() return self.xfer_stats.clone_and_reset()
return None return None
def shutdown(self):
    """Shut down the connector worker and release what it is tracking.

    Steps, in order:
      1. Shut down the handshake initiation executor without waiting.
      2. Make a non-blocking join attempt on the handshake listener thread
         and drop the reference to it.
      3. Release every transfer handle recorded in ``_recving_transfers``.
      4. Release the source and destination descriptor-list handles.
      5. Remove every remote agent recorded in ``_remote_agents``.
      6. Deregister every descriptor in ``_registered_descs``.

    Each collection is cleared (and the source handle reset to 0) after it
    has been processed, so a later call finds nothing left to release in
    steps 2-6.
    """
    self._handshake_initiation_executor.shutdown(wait=False)

    listener_thread = self._nixl_handshake_listener_t
    if listener_thread is not None:
        # timeout=0 returns immediately instead of waiting for the thread.
        listener_thread.join(timeout=0)
        self._nixl_handshake_listener_t = None

    # Entries are (handle, <second element>) pairs; only the handle is used.
    pending_xfer_handles = (xfer_handle
                            for entries in self._recving_transfers.values()
                            for xfer_handle, _ in entries)
    for xfer_handle in pending_xfer_handles:
        self.nixl_wrapper.release_xfer_handle(xfer_handle)
    self._recving_transfers.clear()

    src_handle = self.src_xfer_side_handle
    if src_handle:
        self.nixl_wrapper.release_dlist_handle(src_handle)
        # A falsy value makes the check above skip this branch next time.
        self.src_xfer_side_handle = 0

    for dst_handle in self.dst_xfer_side_handles.values():
        self.nixl_wrapper.release_dlist_handle(dst_handle)
    self.dst_xfer_side_handles.clear()

    known_agent_names = (agent_name
                         for agents in self._remote_agents.values()
                         for agent_name in agents.values())
    for agent_name in known_agent_names:
        self.nixl_wrapper.remove_remote_agent(agent_name)
    self._remote_agents.clear()

    for registered_desc in self._registered_descs:
        self.nixl_wrapper.deregister_memory(registered_desc)
    self._registered_descs.clear()
@contextlib.contextmanager @contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: