From 650d5dbd04e92f5043a11e4a4d86d4f39ee1b694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 1 Jul 2025 13:40:14 +0200 Subject: [PATCH] [Misc] Minor refactor of NIXL background handshake (#20068) Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 60 ++++++++++--------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7a077dce7706c..56ae1acf8571f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -515,6 +515,33 @@ class NixlConnectorWorker: # Remote rank -> agent name. return {p_remote_rank: handshake(path, p_remote_rank)} + def _background_nixl_handshake(self, req_id: str, + remote_engine_id: EngineId, meta: ReqMeta): + # Do NIXL handshake in background and add to _ready_requests when done. + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, meta.remote_host, meta.remote_port, + meta.tp_size) + self._handshake_futures[remote_engine_id] = fut + + def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): + with self._handshake_lock: + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("Handshake with %s failed", eid) + + fut.add_done_callback(done_callback) + + # TODO: handle failure state of future in the + # callback, we want to fail the request in this case. + def request_ready(_f: Future[Any], entry=(req_id, meta)): + self._ready_requests.put(entry) + + fut.add_done_callback(request_ready) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -902,37 +929,14 @@ class NixlConnectorWorker: remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) if remote_engine_id not in self._remote_agents: - # Being optimistic to assume engine is usually ready, apply - # lock only when the optimistic check fails. + # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: if remote_engine_id not in self._remote_agents: - fut = self._handshake_futures.get(remote_engine_id) - if fut is None: - fut = self._handshake_initiation_executor.submit( - self._nixl_handshake, meta.remote_host, - meta.remote_port, meta.tp_size) - self._handshake_futures[remote_engine_id] = fut - - def done_callback(f: Future[dict[int, str]], - eid=remote_engine_id): - with self._handshake_lock: - del self._handshake_futures[eid] - try: - self._remote_agents[eid] = f.result() - except Exception: - logger.exception( - "Handshake with %s failed", eid) - - fut.add_done_callback(done_callback) - - # TODO: handle failure state of future in the - # callback, we want to fail the request in this case. - def request_ready(_f: Future[Any], - entry=(req_id, meta)): - self._ready_requests.put(entry) - - fut.add_done_callback(request_ready) + self._background_nixl_handshake( + req_id, remote_engine_id, meta) continue + + # Handshake already completed, start async read xfer. self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished.