mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
[P/D] Support NIXL connector to disconnect during a clean shutdown (#24423)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
61d1b35561
commit
f05a4f0e34
@ -60,6 +60,9 @@ class FakeNixlWrapper:
|
||||
def register_memory(self, descs, backends) -> None:
    """No-op stand-in: accept the arguments and do nothing."""
    return None
|
||||
|
||||
def deregister_memory(self, descs) -> None:
    """No-op stand-in: accept the descriptors and do nothing."""
    return None
|
||||
|
||||
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
    """Return one freshly generated UUID string per entry of blocks_data.

    memory_type is accepted but not used.
    """
    fake_descs = []
    for _ in blocks_data:
        fake_descs.append(str(uuid.uuid4()))
    return fake_descs
|
||||
|
||||
@ -86,6 +89,12 @@ class FakeNixlWrapper:
|
||||
def release_xfer_handle(self, handle: int) -> None:
    """No-op stand-in: the handle is accepted and ignored."""
    return None
|
||||
|
||||
def release_dlist_handle(self, handle: int) -> None:
    """No-op stand-in: the handle is accepted and ignored."""
    return None
|
||||
|
||||
def remove_remote_agent(self, agent: str) -> None:
    """No-op stand-in: the agent name is accepted and ignored."""
    return None
|
||||
|
||||
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
    """No-op stand-in: the notification is accepted and dropped."""
    return None
|
||||
|
||||
@ -905,3 +914,46 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
|
||||
# Verify get_reg_descs was called with the correct memory_type
|
||||
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
|
||||
assert connector.connector_worker.nixl_memory_type == nixl_memory_type
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_shutdown_cleans_up_resources(dist_init):
    """Check that shutdown() releases every tracked resource exactly once.

    The worker is seeded with one in-flight transfer, a source and a
    destination descriptor-list handle, one remote agent and two registered
    descriptors. shutdown() is then called three times; the release calls
    must have happened once per resource, not once per shutdown() call.
    """
    config = create_vllm_config()
    worker = NixlConnectorWorker(config,
                                 config.kv_transfer_config.engine_id)
    wrapper = worker.nixl_wrapper

    with patch.object(worker, '_handshake_initiation_executor') as executor_mock, \
            patch.object(worker, '_nixl_handshake_listener_t') as listener_mock, \
            patch.object(wrapper, 'release_xfer_handle') as release_xfer_mock, \
            patch.object(wrapper, 'release_dlist_handle') as release_dlist_mock, \
            patch.object(wrapper, 'remove_remote_agent') as remove_agent_mock, \
            patch.object(wrapper, 'deregister_memory') as deregister_mock:

        # Seed the state that shutdown() is expected to tear down.
        worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
        worker.src_xfer_side_handle = 456
        worker.dst_xfer_side_handles = {"engine1": 789}
        worker._remote_agents = {"engine1": {0: "agent1"}}
        worker._registered_descs = ["desc1", "desc2"]

        # The first call performs the cleanup; the two repeats test
        # idempotency.
        for _ in range(3):
            worker.shutdown()

        executor_mock.shutdown.assert_called_with(wait=False)
        listener_mock.join.assert_called_once_with(timeout=0)

        release_xfer_mock.assert_called_once_with(123)
        assert release_dlist_mock.call_count == 2
        release_dlist_mock.assert_any_call(456)  # src handle
        release_dlist_mock.assert_any_call(789)  # dst handle
        remove_agent_mock.assert_called_once_with("agent1")
        assert deregister_mock.call_count == 2
        deregister_mock.assert_any_call("desc1")
        deregister_mock.assert_any_call("desc2")
|
||||
|
||||
@ -250,6 +250,10 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self.connector_worker.copy_blocks:
|
||||
self.connector_worker.save_kv_to_host(self._connector_metadata)
|
||||
|
||||
def shutdown(self):
    """Forward shutdown to the connector worker, if this instance has one."""
    worker = self.connector_worker
    if worker is None:
        # Nothing to shut down when no worker is attached.
        return
    worker.shutdown()
|
||||
|
||||
|
||||
class NixlConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
@ -586,13 +590,6 @@ class NixlConnectorWorker:
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
if executor := getattr(self, "_handshake_initiation_executor", None):
|
||||
executor.shutdown(wait=False)
|
||||
if listener_t := getattr(self, "_nixl_handshake_listener_t", None):
|
||||
listener_t.join(timeout=0)
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, base_port: int,
|
||||
@ -1346,6 +1343,30 @@ class NixlConnectorWorker:
|
||||
return self.xfer_stats.clone_and_reset()
|
||||
return None
|
||||
|
||||
def shutdown(self):
    """Shutdown the connector worker and release what it tracks.

    In order: ask the handshake-initiation executor to shut down without
    waiting, join the handshake listener thread with a zero timeout, then
    release in-flight transfer handles, the source and destination
    descriptor-list handles, the remote agents and the registered memory
    descriptors.

    Each piece of tracked state is cleared (or zeroed / set to ``None``)
    once released, so a repeated call finds nothing left to release. The
    executor's ``shutdown`` is invoked on every call.
    """
    self._handshake_initiation_executor.shutdown(wait=False)

    listener = self._nixl_handshake_listener_t
    if listener is not None:
        # timeout=0: do not wait for the listener thread to finish.
        listener.join(timeout=0)
        self._nixl_handshake_listener_t = None

    # Each _recving_transfers value is a list of (handle, <second item>)
    # pairs; only the handle is needed here.
    in_flight = (xfer_handle
                 for entries in self._recving_transfers.values()
                 for xfer_handle, _ in entries)
    for pending in in_flight:
        self.nixl_wrapper.release_xfer_handle(pending)
    self._recving_transfers.clear()

    # A falsy source handle means there is nothing to release.
    if self.src_xfer_side_handle:
        self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
        self.src_xfer_side_handle = 0

    for dst_handle in self.dst_xfer_side_handles.values():
        self.nixl_wrapper.release_dlist_handle(dst_handle)
    self.dst_xfer_side_handles.clear()

    # _remote_agents is a mapping of mappings; the inner values are the
    # agent names.
    known_agents = (name
                    for per_engine in self._remote_agents.values()
                    for name in per_engine.values())
    for remote_name in known_agents:
        self.nixl_wrapper.remove_remote_agent(remote_name)
    self._remote_agents.clear()

    for registered in self._registered_descs:
        self.nixl_wrapper.deregister_memory(registered)
    self._registered_descs.clear()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user