[P/D] Asynchronously do _nixl_handshake (#19836)

Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Authored by lkchen on 2025-06-24 12:46:10 -07:00; committed by GitHub
parent 8619e7158c
commit 91f7d9d0b6
2 changed files with 259 additions and 91 deletions
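
At a glance, the change moves the NIXL handshake off the hot path of start_load_kv: the worker submits the handshake to a single-worker ThreadPoolExecutor, caches the resulting Future per remote engine ID, and parks requests on a queue until the handshake's done-callback fires. The stand-alone sketch below illustrates that pattern under simplified assumptions; the names (AsyncHandshakeCache, start_load, drain_ready) and the string-valued "agent" are illustrative only and are not the vLLM/NIXL API.

import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor

# Illustrative aliases; the real connector exchanges NIXL agent metadata
# per TP rank rather than plain strings.
EngineId = str
ReqId = str


class AsyncHandshakeCache:
    """Caches one in-flight handshake per remote engine and defers
    requests until that handshake resolves."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(
            max_workers=1,  # serialize handshakes; NIXL may not be thread-safe
            thread_name_prefix="handshake-initiator")
        self._agents: dict[EngineId, str] = {}
        self._futures: dict[EngineId, Future[str]] = {}
        self._ready: queue.Queue[ReqId] = queue.Queue()
        self._lock = threading.RLock()

    def _handshake(self, engine_id: EngineId) -> str:
        # Stand-in for the slow ZMQ metadata exchange + agent registration.
        time.sleep(0.2)
        return f"agent-for-{engine_id}"

    def start_load(self, req_id: ReqId, engine_id: EngineId) -> None:
        """Non-blocking: either read immediately or park the request."""
        if engine_id in self._agents:
            print(f"{req_id}: remote agent known, reading blocks now")
            return
        with self._lock:
            if engine_id in self._agents:
                # Handshake finished between the optimistic check and the lock.
                print(f"{req_id}: remote agent known, reading blocks now")
                return
            fut = self._futures.get(engine_id)
            if fut is None:
                fut = self._executor.submit(self._handshake, engine_id)
                self._futures[engine_id] = fut

                def record(f: Future[str], eid: EngineId = engine_id) -> None:
                    # Runs in the executor thread once the handshake completes.
                    with self._lock:
                        del self._futures[eid]
                        self._agents[eid] = f.result()

                fut.add_done_callback(record)
            # Park the request; a later step drains the ready queue.
            fut.add_done_callback(lambda _f, r=req_id: self._ready.put(r))

    def drain_ready(self) -> None:
        # Called on each step to start transfers whose handshake finished.
        while not self._ready.empty():
            print(f"{self._ready.get_nowait()}: handshake done, reading blocks")


if __name__ == "__main__":
    cache = AsyncHandshakeCache()
    cache.start_load("req-0", "remote-engine")
    cache.start_load("req-1", "remote-engine")  # reuses the same future
    time.sleep(0.3)  # give the background handshake time to finish
    cache.drain_ready()

The single-worker executor mirrors the commit's own choice: handshakes are serialized because NIXL is not guaranteed to be thread-safe.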

Changed file 1 of 2

@@ -7,13 +7,6 @@ from collections import defaultdict
 from typing import Optional
 from unittest.mock import patch
 
-import pytest
-
-try:
-    from nixl._api import nixl_agent as NixlWrapper
-except ImportError:
-    NixlWrapper = None
-
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.
 
-    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
     """
 
     AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
-        self.add_remote_agent(
+        remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,40 +181,101 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
             ))
+        return {0: remote_agent_name}
 
-@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper)
-def test_multi_xfer_one_engine(
-        # dist_init is a fixture that initializes the distributed environment.
-        dist_init):
-    """Test case where multiple xfers are initiated to the same engine.
-
-    This test triggers the connector to load remote KV for the same
-    `request_id`. The transfer is not done immediately due to
-    `set_cycles_before_xfer_done`, so there is a state where there are multiple
-    transfer states for the same `request_id`, and `get_finished` should handle
-    it correctly (wait for all transfers to be done).
-    """
-    vllm_config = create_vllm_config()
-
-    request_id = "req_id"
-
-    # Test worker role in decode server.
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
-                                                         connector.engine_id,
-                                                         hand_shake_latency=0)
-    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
-    for i in range(4):
-        metadata = NixlConnectorMetadata()
-        metadata.add_new_req(request_id=request_id,
-                             local_block_ids=[i + 1, i + 2, i + 3],
-                             kv_transfer_params={
-                                 "remote_block_ids": [i + 4, i + 5, i + 6],
-                                 "remote_engine_id":
-                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
-                                 "remote_host": "localhost",
+
+class TestNixlHandshake:
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_multi_xfer_one_engine(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test case where multiple xfers are initiated to the same engine.
+
+        This test triggers the connector to load remote KV for the same
+        `request_id`. The transfer is not done immediately due to
+        `set_cycles_before_xfer_done`, so there is a state where there are
+        multiple transfer states for the same `request_id`, and `get_finished`
+        should handle it correctly (wait for all transfers to be done).
+        """
+        vllm_config = create_vllm_config()
+
+        request_id = "req_id"
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0)
+        assert isinstance(connector.connector_worker.nixl_wrapper,
+                          FakeNixlWrapper)
+        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        num_xfers = 4
+        while True:
+            # For the same request_id, initiate multiple xfers across different
+            # round of `execute_model` calls.
+            metadata = NixlConnectorMetadata()
+            if num_xfers > 0:
+                num_xfers -= 1
+                metadata.add_new_req(
+                    request_id=request_id,
+                    local_block_ids=[
+                        num_xfers + 1, num_xfers + 2, num_xfers + 3
+                    ],
+                    kv_transfer_params={
+                        "remote_block_ids":
+                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
+                        "remote_engine_id":
+                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                        "remote_host":
+                        "localhost",
+                        "remote_port":
+                        1234,
+                    })
+            connector.bind_connector_metadata(metadata)
+
+            # Mimic maybe_setup_kv_connector in gpu_model_runner.
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+
+            # Mimic get_finished_kv_transfers in gpu_model_runner.
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                assert request_id in done_recving
+                break
+
+            connector.clear_connector_metadata()
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_async_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that NixlConnector's start_load_kv should be non-blocking."""
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        metadata.add_new_req(request_id="id",
+                             local_block_ids=[1, 2, 3],
+                             kv_transfer_params={
+                                 "remote_block_ids": [4, 5, 6],
+                                 "remote_engine_id":
+                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                 "remote_host": "localhost",
@@ -228,19 +283,74 @@ def test_multi_xfer_one_engine(
                              })
         connector.bind_connector_metadata(metadata)
-        dummy_ctx = ForwardContext(
-            no_compile_layers={},
-            attn_metadata={},
-            virtual_engine=0,
-        )
-        _before_load = time.perf_counter()
-        connector.start_load_kv(dummy_ctx)
-        _after_load = time.perf_counter()
-        assert _after_load - _before_load < 0.1, "start_load_kv took " \
-            f"{_after_load - _before_load} seconds"
-
-    while True:
-        _, done_recving = connector.get_finished(finished_req_ids=set())
-        if len(done_recving) > 0:
-            assert request_id in done_recving
-            break
+        timeout = 2.5
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                return
+        raise TimeoutError("Took too long to complete async handshake.")
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_concurrent_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that multiple start_load_kv calls should occur concurrently."""
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        total_reqs = 5
+        for i in range(total_reqs):
+            metadata.add_new_req(request_id=f"id_{i}",
+                                 local_block_ids=[1, 2, 3],
+                                 kv_transfer_params={
+                                     "remote_block_ids": [4, 5, 6],
+                                     "remote_engine_id":
+                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                     "remote_host": "localhost",
+                                     "remote_port": 1234,
+                                 })
+        connector.bind_connector_metadata(metadata)
+        timeout = 2.5 * total_reqs
+        cnt_finished_reqs = 0
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                cnt_finished_reqs += len(done_recving)
+                if cnt_finished_reqs == total_reqs:
+                    return
+        raise TimeoutError("Took too long to complete async handshake.")

Changed file 2 of 2

@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import contextlib
 import math
+import queue
 import threading
 import time
 import uuid
 from collections import defaultdict
 from collections.abc import Iterator
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
@@ -23,6 +25,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
 from vllm.distributed.utils import divide
+from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
@@ -31,7 +34,6 @@ from vllm.v1.request import RequestStatus
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
@@ -71,7 +73,7 @@ class ReqMeta:
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
-    remote_engine_id: str
+    remote_engine_id: EngineId
 
 
 class NixlConnectorMetadata(KVConnectorMetadata):
@@ -81,7 +83,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
 
     def add_new_req(
         self,
-        request_id: str,
+        request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
@@ -102,7 +104,7 @@ class NixlConnector(KVConnectorBase_V1):
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
 
         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler : Optional[NixlConnectorScheduler] = \
+            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                 NixlConnectorScheduler(vllm_config, self.engine_id)
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
@@ -186,7 +188,7 @@ class NixlConnectorScheduler:
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)
@@ -343,7 +345,7 @@ class NixlConnectorWorker:
         # Each TP rank listens/queries on the base_port + tp_rank.
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
 
         # Metadata.
@@ -386,8 +388,17 @@ class NixlConnectorWorker:
         self._done_sending_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
-        # Background thread for establishing new connections.
+        # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
+        # Background thread for initializing new NIXL handshakes.
+        self._handshake_initiation_executor = ThreadPoolExecutor(
+            # NIXL is not guaranteed to be thread-safe, limit 1 worker.
+            max_workers=1,
+            thread_name_prefix="vllm-nixl-handshake-initiator")
+        self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
+        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
+        # Protects _handshake_futures and _remote_agents.
+        self._handshake_lock = threading.RLock()
 
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
@@ -416,6 +427,12 @@ class NixlConnectorWorker:
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
 
+    def __del__(self):
+        """Cleanup background threads on destruction."""
+        self._handshake_initiation_executor.shutdown(wait=False)
+        if self._nixl_handshake_listener_t:
+            self._nixl_handshake_listener_t.join(timeout=0)
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, base_port: int,
@@ -443,7 +460,7 @@ class NixlConnectorWorker:
                     "Connection listener got unexpected message %s", msg)
             sock.send_multipart((identity, b"", encoded_data))
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
@@ -452,7 +469,7 @@ class NixlConnectorWorker:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(path: str, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -462,19 +479,20 @@ class NixlConnectorWorker:
                 got_metadata_time = time.perf_counter()
 
                 # Register Remote agent.
-                self.add_remote_agent(metadata, rank)
+                remote_agent_name = self.add_remote_agent(metadata, rank)
                 setup_agent_time = time.perf_counter()
 
                 logger.debug("NIXL handshake: get metadata took: %s",
                              got_metadata_time - start_time)
                 logger.debug("NIXL handshake: add agent took: %s",
                              setup_agent_time - got_metadata_time)
-                return metadata
+                return metadata, remote_agent_name
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = make_zmq_path("tcp", host, port)
         logger.debug("Querying master rank metadata on path: %s", path)
-        metadata = handshake(path, 0)
+        rank_to_agent_name: dict[int, str] = {}
+        metadata, rank_to_agent_name[0] = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
@@ -484,7 +502,10 @@ class NixlConnectorWorker:
             path = make_zmq_path("tcp", host, port + p_remote_rank)
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            _ = handshake(path, p_remote_rank)
+            _, rank_to_agent_name[p_remote_rank] = handshake(
+                path, p_remote_rank)
+
+        return rank_to_agent_name
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -621,11 +642,11 @@ class NixlConnectorWorker:
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
-        ready_event.wait()
+        ready_event.wait()  # Wait for listener ZMQ socket to be ready.
 
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
-                         remote_tp_rank: int = 0):
+                         remote_tp_rank: int = 0) -> str:
         """
         Add the remote NIXL agent and prepare the descriptors for reading cache
         blocks from remote.
@@ -666,8 +687,8 @@ class NixlConnectorWorker:
         """ # noqa: E501
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
-        if remote_tp_rank in self._remote_agents.get(engine_id, ()):
-            return
+        if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            return self._remote_agents[engine_id][remote_tp_rank]
 
         if engine_id in self._tp_size:
             assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
@@ -677,9 +698,8 @@ class NixlConnectorWorker:
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
 
-        self._remote_agents[engine_id][
-            remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
-                nixl_agent_meta.agent_metadata)
+        remote_agent_name = self.nixl_wrapper.add_remote_agent(
+            nixl_agent_meta.agent_metadata)
 
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
@@ -708,8 +728,9 @@ class NixlConnectorWorker:
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )
 
-        assert self.block_size == remote_block_size, "Remote P worker with " \
-            "different block size is not supported"
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported "
+            f"{self.block_size=} {remote_block_size=}")
 
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -748,7 +769,9 @@ class NixlConnectorWorker:
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
         self.dst_xfer_side_handles[
             engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id][remote_tp_rank], descs)
+                remote_agent_name, descs)
+
+        return remote_agent_name
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -866,33 +889,68 @@ class NixlConnectorWorker:
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.requests.items():
+            remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
-                meta.remote_engine_id, len(meta.local_block_ids),
+                remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))
-            self._read_blocks(
-                request_id=req_id,
-                dst_engine_id=meta.remote_engine_id,
-                local_block_ids=meta.local_block_ids,
-                remote_block_ids=meta.remote_block_ids,
-                remote_host=meta.remote_host,
-                remote_port=meta.remote_port,
-            )
+            if remote_engine_id not in self._remote_agents:
+                # Being optimistic to assume engine is usually ready, apply
+                # lock only when the optimistic check fails.
+                with self._handshake_lock:
+                    if remote_engine_id not in self._remote_agents:
+                        fut = self._handshake_futures.get(remote_engine_id)
+                        if fut is None:
+                            fut = self._handshake_initiation_executor.submit(
+                                self._nixl_handshake, meta.remote_host,
+                                meta.remote_port)
+                            self._handshake_futures[remote_engine_id] = fut
+
+                            def done_callback(f: Future[dict[int, str]],
+                                              eid=remote_engine_id):
+                                with self._handshake_lock:
+                                    del self._handshake_futures[eid]
+                                    try:
+                                        self._remote_agents[eid] = f.result()
+                                    except Exception:
+                                        logger.exception(
+                                            "Handshake with %s failed", eid)
+
+                            fut.add_done_callback(done_callback)
+
+                        # TODO: handle failure state of future in the
+                        # callback, we want to fail the request in this case.
+                        def request_ready(_f: Future[Any],
+                                          entry=(req_id, meta)):
+                            self._ready_requests.put(entry)
+
+                        fut.add_done_callback(request_ready)
+                        continue
+
+            self._read_blocks_for_req(req_id, meta)
+
+        # Start transfers for requests whose handshakes have now finished.
+        while not self._ready_requests.empty():
+            self._read_blocks_for_req(*self._ready_requests.get_nowait())
+
+    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
+        logger.debug(
+            "Remote agent %s available, calling _read_blocks for req %s",
+            meta.remote_engine_id, req_id)
+        self._read_blocks(
+            request_id=req_id,
+            dst_engine_id=meta.remote_engine_id,
+            local_block_ids=meta.local_block_ids,
+            remote_block_ids=meta.remote_block_ids,
+        )
 
     def _read_blocks(
         self,
         local_block_ids: list[int],
         remote_block_ids: list[int],
-        remote_host: str,
-        remote_port: int,
         dst_engine_id: str,
         request_id: str,
     ):
-        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
-        if dst_engine_id not in self._remote_agents:
-            self._nixl_handshake(remote_host, remote_port)
-
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the