mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 14:06:32 +08:00
[NIXL] Terminate handshake listener thread in shutdown (#26404)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
675aa2ec64
commit
4ca13a8667
@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
|
||||
|
||||
def run_test_and_cleanup():
|
||||
llm = LLM(**llm_kwargs)
|
||||
try:
|
||||
_run_abort_timeout_test(llm, timeout)
|
||||
finally:
|
||||
llm.llm_engine.engine_core.shutdown()
|
||||
|
||||
# Build runtime_env only if we're using Ray
|
||||
if distributed_executor_backend == "ray":
|
||||
with _make_fake_nixl_pkg() as working_dir:
|
||||
@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
},
|
||||
}
|
||||
ray.init(runtime_env=runtime_env)
|
||||
|
||||
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||
try:
|
||||
run_test_and_cleanup()
|
||||
finally:
|
||||
ray.shutdown()
|
||||
else:
|
||||
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||
run_test_and_cleanup()
|
||||
|
||||
|
||||
def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
|
||||
def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
"""Helper function to run the abort timeout test logic."""
|
||||
llm = LLM(**llm_kwargs)
|
||||
remote_prefill_opts = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init):
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
|
||||
),
|
||||
) as mock_thread,
|
||||
): # noqa: E501
|
||||
# Create connector
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init):
|
||||
mock_wrapper_instance = mock_nixl_wrapper.return_value
|
||||
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
|
||||
|
||||
# Reassure the shutdown() check that the thread is terminated
|
||||
mock_thread.return_value.is_alive.return_value = False
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
with (
|
||||
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
||||
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
|
||||
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
|
||||
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
||||
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
||||
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
||||
@ -1182,6 +1194,8 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
worker._remote_agents = {"engine1": {0: "agent1"}}
|
||||
worker._registered_descs = ["desc1", "desc2"]
|
||||
|
||||
mock_listener.is_alive.return_value = False
|
||||
|
||||
worker.shutdown()
|
||||
|
||||
# Test idempotency
|
||||
@ -1189,7 +1203,8 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
worker.shutdown()
|
||||
|
||||
mock_exec.shutdown.assert_called_with(wait=False)
|
||||
mock_listener.join.assert_called_once_with(timeout=0)
|
||||
mock_event.set.assert_called_once()
|
||||
mock_listener.join.assert_called_once_with(timeout=1.0)
|
||||
|
||||
mock_rel_xfer.assert_called_once_with(123)
|
||||
assert mock_rel_dlist.call_count == 2
|
||||
|
||||
@ -520,6 +520,8 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
@ -719,6 +721,7 @@ class NixlConnectorWorker:
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||
self._nixl_handshake_listener_stop_event: threading.Event | None = None
|
||||
# Background thread for initializing new NIXL handshakes.
|
||||
self._handshake_initiation_executor = ThreadPoolExecutor(
|
||||
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
||||
@ -773,6 +776,7 @@ class NixlConnectorWorker:
|
||||
def _nixl_handshake_listener(
|
||||
metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event,
|
||||
stop_event: threading.Event,
|
||||
base_port: int,
|
||||
tp_rank: int,
|
||||
):
|
||||
@ -791,7 +795,14 @@ class NixlConnectorWorker:
|
||||
logger.debug("Starting listening on path: %s", path)
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
ready_event.set()
|
||||
while True:
|
||||
poller = zmq.Poller()
|
||||
poller.register(sock, zmq.POLLIN)
|
||||
while not stop_event.is_set():
|
||||
events = dict(
|
||||
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
|
||||
)
|
||||
if sock not in events:
|
||||
continue
|
||||
identity, _, msg = sock.recv_multipart()
|
||||
if msg != GET_META_MSG:
|
||||
logger.warning("Connection listener got unexpected message %s", msg)
|
||||
@ -1101,14 +1112,21 @@ class NixlConnectorWorker:
|
||||
attn_backend_name=self.backend_name,
|
||||
kv_cache_layout=self.kv_cache_layout,
|
||||
)
|
||||
ready_event = threading.Event()
|
||||
ready_event, stop_event = threading.Event(), threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
|
||||
args=(
|
||||
metadata,
|
||||
ready_event,
|
||||
stop_event,
|
||||
self.side_channel_port,
|
||||
self.tp_rank,
|
||||
),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener",
|
||||
)
|
||||
self._nixl_handshake_listener_t.start()
|
||||
self._nixl_handshake_listener_stop_event = stop_event
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
|
||||
def add_remote_agent(
|
||||
@ -1782,11 +1800,19 @@ class NixlConnectorWorker:
|
||||
self._invalid_block_ids = set()
|
||||
return result
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the connector worker."""
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
if self._nixl_handshake_listener_stop_event is not None:
|
||||
self._nixl_handshake_listener_stop_event.set()
|
||||
self._nixl_handshake_listener_stop_event = None
|
||||
if self._nixl_handshake_listener_t is not None:
|
||||
self._nixl_handshake_listener_t.join(timeout=0)
|
||||
# Generous timeout to allow the thread to exit
|
||||
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
|
||||
assert not self._nixl_handshake_listener_t.is_alive()
|
||||
self._nixl_handshake_listener_t = None
|
||||
for handles in self._recving_transfers.values():
|
||||
for handle, _ in handles:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user