diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a80abf10a8a15..bb77c4f2b62a6 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + def run_test_and_cleanup(): + llm = LLM(**llm_kwargs) + try: + _run_abort_timeout_test(llm, timeout) + finally: + llm.llm_engine.engine_core.shutdown() + # Build runtime_env only if we're using Ray if distributed_executor_backend == "ray": with _make_fake_nixl_pkg() as working_dir: @@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): }, } ray.init(runtime_env=runtime_env) - - _run_abort_timeout_test(llm_kwargs, timeout) + try: + run_test_and_cleanup() + finally: + ray.shutdown() else: - _run_abort_timeout_test(llm_kwargs, timeout) + run_test_and_cleanup() -def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): +def _run_abort_timeout_test(llm: LLM, timeout: int): """Helper function to run the abort timeout test logic.""" - llm = LLM(**llm_kwargs) remote_prefill_opts = { "do_remote_decode": True, "do_remote_prefill": False, @@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init): ), patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" - ), + ) as mock_thread, ): # noqa: E501 # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) @@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init): mock_wrapper_instance = mock_nixl_wrapper.return_value connector.connector_worker.nixl_wrapper = mock_wrapper_instance + # Reassure the shutdown() check that the thread is terminated + mock_thread.return_value.is_alive.return_value = False + # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init): with ( patch.object(worker, "_handshake_initiation_executor") as mock_exec, patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, + patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event, patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, @@ -1182,6 +1194,8 @@ def test_shutdown_cleans_up_resources(dist_init): worker._remote_agents = {"engine1": {0: "agent1"}} worker._registered_descs = ["desc1", "desc2"] + mock_listener.is_alive.return_value = False + worker.shutdown() # Test idempotency @@ -1189,7 +1203,8 @@ def test_shutdown_cleans_up_resources(dist_init): worker.shutdown() mock_exec.shutdown.assert_called_with(wait=False) - mock_listener.join.assert_called_once_with(timeout=0) + mock_event.set.assert_called_once() + mock_listener.join.assert_called_once_with(timeout=1.0) mock_rel_xfer.assert_called_once_with(123) assert mock_rel_dlist.call_count == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d73e05562951d..ae7144cf78472 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -520,6 +520,8 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" + _POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms + @dataclass class TpKVTopology: """ @@ -719,6 +721,7 @@ class NixlConnectorWorker: # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: threading.Thread | None = None + self._nixl_handshake_listener_stop_event: threading.Event | None = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -773,6 +776,7 @@ class NixlConnectorWorker: def _nixl_handshake_listener( metadata: NixlAgentMetadata, ready_event: threading.Event, + stop_event: threading.Event, base_port: int, tp_rank: int, ): @@ -791,7 +795,14 @@ class NixlConnectorWorker: logger.debug("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() - while True: + poller = zmq.Poller() + poller.register(sock, zmq.POLLIN) + while not stop_event.is_set(): + events = dict( + poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000) + ) + if sock not in events: + continue identity, _, msg = sock.recv_multipart() if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) @@ -1101,14 +1112,21 @@ class NixlConnectorWorker: attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout, ) - ready_event = threading.Event() + ready_event, stop_event = threading.Event(), threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + args=( + metadata, + ready_event, + stop_event, + self.side_channel_port, + self.tp_rank, + ), daemon=True, name="nixl_handshake_listener", ) self._nixl_handshake_listener_t.start() + self._nixl_handshake_listener_stop_event = stop_event ready_event.wait() # Wait for listener ZMQ socket to be ready. def add_remote_agent( @@ -1782,11 +1800,19 @@ class NixlConnectorWorker: self._invalid_block_ids = set() return result + def __del__(self): + self.shutdown() + def shutdown(self): """Shutdown the connector worker.""" self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_stop_event is not None: + self._nixl_handshake_listener_stop_event.set() + self._nixl_handshake_listener_stop_event = None if self._nixl_handshake_listener_t is not None: - self._nixl_handshake_listener_t.join(timeout=0) + # Generous timeout to allow the thread to exit + self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10) + assert not self._nixl_handshake_listener_t.is_alive() self._nixl_handshake_listener_t = None for handles in self._recving_transfers.values(): for handle, _ in handles: