[NIXL] Terminate handshake listener thread in shutdown (#26404)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-10-22 15:59:53 +01:00 committed by GitHub
parent 675aa2ec64
commit 4ca13a8667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 11 deletions

View File

@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
# Build the LLM from llm_kwargs, run the abort-timeout scenario against it,
# and always shut down the engine core afterwards, even if the scenario raises.
def run_test_and_cleanup():
llm = LLM(**llm_kwargs)
try:
_run_abort_timeout_test(llm, timeout)
finally:
# Explicit shutdown so the engine core is torn down before the test returns.
llm.llm_engine.engine_core.shutdown()
# Build runtime_env only if we're using Ray
if distributed_executor_backend == "ray":
with _make_fake_nixl_pkg() as working_dir:
@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
},
}
ray.init(runtime_env=runtime_env)
_run_abort_timeout_test(llm_kwargs, timeout)
try:
run_test_and_cleanup()
finally:
ray.shutdown()
else:
_run_abort_timeout_test(llm_kwargs, timeout)
run_test_and_cleanup()
def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic."""
llm = LLM(**llm_kwargs)
remote_prefill_opts = {
"do_remote_decode": True,
"do_remote_prefill": False,
@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init):
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
),
) as mock_thread,
): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init):
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init):
with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
@ -1182,6 +1194,8 @@ def test_shutdown_cleans_up_resources(dist_init):
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
mock_listener.is_alive.return_value = False
worker.shutdown()
# Test idempotency
@ -1189,7 +1203,8 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2

View File

@ -520,6 +520,8 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
@dataclass
class TpKVTopology:
"""
@ -719,6 +721,7 @@ class NixlConnectorWorker:
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._nixl_handshake_listener_stop_event: threading.Event | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
@ -773,6 +776,7 @@ class NixlConnectorWorker:
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
stop_event: threading.Event,
base_port: int,
tp_rank: int,
):
@ -791,7 +795,14 @@ class NixlConnectorWorker:
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
while True:
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
while not stop_event.is_set():
events = dict(
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
)
if sock not in events:
continue
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
@ -1101,14 +1112,21 @@ class NixlConnectorWorker:
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout,
)
ready_event = threading.Event()
ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
args=(
metadata,
ready_event,
stop_event,
self.side_channel_port,
self.tp_rank,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
self._nixl_handshake_listener_stop_event = stop_event
ready_event.wait() # Wait for listener ZMQ socket to be ready.
def add_remote_agent(
@ -1782,11 +1800,19 @@ class NixlConnectorWorker:
self._invalid_block_ids = set()
return result
# Finalizer: delegates to shutdown() so the worker's resources are released
# when the object is destroyed without an explicit shutdown() call.
# NOTE(review): shutdown() asserts that the handshake listener thread has
# stopped, so that assertion can also fire from this finalizer.
def __del__(self):
self.shutdown()
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_stop_event is not None:
self._nixl_handshake_listener_stop_event.set()
self._nixl_handshake_listener_stop_event = None
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
# Generous timeout to allow the thread to exit
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
assert not self._nixl_handshake_listener_t.is_alive()
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles: