[Misc] Minor refactor of NIXL background handshake (#20068)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-07-01 13:40:14 +02:00 committed by GitHub
parent 9025a9a705
commit 650d5dbd04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -515,6 +515,33 @@ class NixlConnectorWorker:
# Remote rank -> agent name.
return {p_remote_rank: handshake(path, p_remote_rank)}
def _background_nixl_handshake(self, req_id: str,
remote_engine_id: EngineId, meta: ReqMeta):
# Do NIXL handshake in background and add to _ready_requests when done.
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host, meta.remote_port,
meta.tp_size)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
with self._handshake_lock:
del self._handshake_futures[eid]
try:
self._remote_agents[eid] = f.result()
except Exception:
logger.exception("Handshake with %s failed", eid)
fut.add_done_callback(done_callback)
# TODO: handle failure state of future in the
# callback, we want to fail the request in this case.
def request_ready(_f: Future[Any], entry=(req_id, meta)):
self._ready_requests.put(entry)
fut.add_done_callback(request_ready)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@ -902,37 +929,14 @@ class NixlConnectorWorker:
remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
if remote_engine_id not in self._remote_agents:
# Being optimistic to assume engine is usually ready, apply
# lock only when the optimistic check fails.
# Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock:
if remote_engine_id not in self._remote_agents:
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host,
meta.remote_port, meta.tp_size)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]],
eid=remote_engine_id):
with self._handshake_lock:
del self._handshake_futures[eid]
try:
self._remote_agents[eid] = f.result()
except Exception:
logger.exception(
"Handshake with %s failed", eid)
fut.add_done_callback(done_callback)
# TODO: handle failure state of future in the
# callback, we want to fail the request in this case.
def request_ready(_f: Future[Any],
entry=(req_id, meta)):
self._ready_requests.put(entry)
fut.add_done_callback(request_ready)
self._background_nixl_handshake(
req_id, remote_engine_id, meta)
continue
# Handshake already completed, start async read xfer.
self._read_blocks_for_req(req_id, meta)
# Start transfers for requests whose handshakes have now finished.