[P/D] Asynchronously do _nixl_handshake (#19836)
Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 8619e7158c
commit 91f7d9d0b6
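The change below makes the NIXL handshake non-blocking: instead of performing the ZMQ exchange inline in `start_load_kv`, the worker submits `_nixl_handshake` to a single-worker `ThreadPoolExecutor`, tracks one `Future` per remote engine id under a lock, and parks affected requests on a ready queue that is drained on later `start_load_kv` calls. The sketch below illustrates that pattern in isolation; it is not the vLLM implementation, and names such as `RemoteAgentRegistry`, `ensure_agent`, and `drain_ready` are invented for the example.

# Illustrative sketch only (assumed names, not vLLM APIs).
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor


class RemoteAgentRegistry:

    def __init__(self):
        # One worker: the underlying transport may not be thread-safe.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures: dict[str, Future] = {}  # engine_id -> in-flight handshake
        self._agents: dict[str, dict[int, str]] = {}  # engine_id -> rank -> agent name
        self._lock = threading.RLock()
        self._ready: queue.Queue[str] = queue.Queue()  # req_ids ready to proceed

    def ensure_agent(self, engine_id: str, req_id: str, handshake) -> bool:
        """Return True if the remote agent is already known; otherwise start
        an async handshake and remember req_id so it can resume later."""
        if engine_id in self._agents:  # optimistic, lock-free fast path
            return True
        with self._lock:  # re-check under the lock
            if engine_id in self._agents:
                return True
            fut = self._futures.get(engine_id)
            if fut is None:
                fut = self._executor.submit(handshake, engine_id)
                self._futures[engine_id] = fut

                def done(f: Future, eid=engine_id):
                    with self._lock:
                        del self._futures[eid]
                        try:
                            self._agents[eid] = f.result()
                        except Exception:
                            pass  # real code logs; failing the request is a TODO

                fut.add_done_callback(done)
            # Resume this request once the handshake completes.
            fut.add_done_callback(lambda _f, r=req_id: self._ready.put(r))
            return False

    def drain_ready(self) -> list[str]:
        """Requests whose handshakes finished since the last call."""
        done: list[str] = []
        while not self._ready.empty():
            done.append(self._ready.get_nowait())
        return done

The single worker mirrors the comment in the diff that NIXL is not guaranteed to be thread-safe; the lock is only taken after the optimistic membership check fails, so the common case (agent already known) stays lock-free.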
@@ -7,13 +7,6 @@ from collections import defaultdict
from typing import Optional
from unittest.mock import patch

import pytest

try:
    from nixl._api import nixl_agent as NixlWrapper
except ImportError:
    NixlWrapper = None

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from NixlWrapper because NixlWrapper could be None.
    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.
    """

    AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
        super().__init__(*args, **kwargs)
        self._hand_shake_latency = hand_shake_latency

    def _nixl_handshake(self, host: str, port: int):
    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
        # Mimic slow _nixl_handshake, as well as bypass zmq communication.
        time.sleep(self._hand_shake_latency)
        # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
        self.num_blocks = 1
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        self.add_remote_agent(
        remote_agent_name = self.add_remote_agent(
            NixlAgentMetadata(
                engine_id=self.REMOTE_ENGINE_ID,
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,40 +181,101 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                block_len=self.block_len,
                attn_backend_name=self.backend_name,
            ))
        return {0: remote_agent_name}


@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_multi_xfer_one_engine(
        # dist_init is a fixture that initializes the distributed environment.
        dist_init):
    """Test case where multiple xfers are initiated to the same engine.

    This test triggers the connector to load remote KV for the same
    `request_id`. The transfer is not done immediately due to
    `set_cycles_before_xfer_done`, so there is a state where there are multiple
    transfer states for the same `request_id`, and `get_finished` should handle
    it correctly (wait for all transfers to be done).
    """
    vllm_config = create_vllm_config()
class TestNixlHandshake:

    request_id = "req_id"
    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_multi_xfer_one_engine(
            self,
            # dist_init is a fixture that initializes the distributed environment.
            dist_init):
        """Test case where multiple xfers are initiated to the same engine.

        This test triggers the connector to load remote KV for the same
        `request_id`. The transfer is not done immediately due to
        `set_cycles_before_xfer_done`, so there is a state where there are
        multiple transfer states for the same `request_id`, and `get_finished`
        should handle it correctly (wait for all transfers to be done).
        """
        vllm_config = create_vllm_config()

    # Test worker role in decode server.
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)
    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
    for i in range(4):
        request_id = "req_id"

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0)
        assert isinstance(connector.connector_worker.nixl_wrapper,
                          FakeNixlWrapper)
        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
        num_xfers = 4
        while True:
            # For the same request_id, initiate multiple xfers across different
            # round of `execute_model` calls.
            metadata = NixlConnectorMetadata()
            if num_xfers > 0:
                num_xfers -= 1
                metadata.add_new_req(
                    request_id=request_id,
                    local_block_ids=[
                        num_xfers + 1, num_xfers + 2, num_xfers + 3
                    ],
                    kv_transfer_params={
                        "remote_block_ids":
                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
                        "remote_engine_id":
                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                        "remote_host":
                        "localhost",
                        "remote_port":
                        1234,
                    })
            connector.bind_connector_metadata(metadata)

            # Mimic maybe_setup_kv_connector in gpu_model_runner.
            dummy_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            _before_load = time.perf_counter()
            connector.start_load_kv(dummy_ctx)
            _after_load = time.perf_counter()
            assert _after_load - _before_load < 0.1, "start_load_kv took " \
                f"{_after_load - _before_load} seconds"

            # Mimic get_finished_kv_transfers in gpu_model_runner.
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                assert request_id in done_recving
                break

            connector.clear_connector_metadata()

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_async_load_kv(
            self,
            # dist_init is a fixture that initializes the distributed environment.
            dist_init):
        """Test that NixlConnector's start_load_kv should be non-blocking."""

        vllm_config = create_vllm_config()

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id)
        metadata = NixlConnectorMetadata()
        metadata.add_new_req(request_id=request_id,
                             local_block_ids=[i + 1, i + 2, i + 3],
        metadata.add_new_req(request_id="id",
                             local_block_ids=[1, 2, 3],
                             kv_transfer_params={
                                 "remote_block_ids": [i + 4, i + 5, i + 6],
                                 "remote_block_ids": [4, 5, 6],
                                 "remote_engine_id":
                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                 "remote_host": "localhost",
@@ -228,19 +283,74 @@ def test_multi_xfer_one_engine(
                             })
        connector.bind_connector_metadata(metadata)

        dummy_ctx = ForwardContext(
            no_compile_layers={},
            attn_metadata={},
            virtual_engine=0,
        )
        _before_load = time.perf_counter()
        connector.start_load_kv(dummy_ctx)
        _after_load = time.perf_counter()
        assert _after_load - _before_load < 0.1, "start_load_kv took " \
            f"{_after_load - _before_load} seconds"
        timeout = 2.5
        start = time.perf_counter()
        while time.perf_counter() - start < timeout:
            dummy_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            _before_load = time.perf_counter()
            connector.start_load_kv(dummy_ctx)
            _after_load = time.perf_counter()
            assert _after_load - _before_load < 0.1, "start_load_kv took " \
                f"{_after_load - _before_load} seconds"
            time.sleep(0.5)  # backoff for the async handshake to complete.
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                return
        raise TimeoutError("Took too long to complete async handshake.")

        while True:
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                assert request_id in done_recving
                break
    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_concurrent_load_kv(
            self,
            # dist_init is a fixture that initializes the distributed environment.
            dist_init):
        """Test that multiple start_load_kv calls should occur concurrently."""

        vllm_config = create_vllm_config()

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id)
        metadata = NixlConnectorMetadata()
        total_reqs = 5
        for i in range(total_reqs):
            metadata.add_new_req(request_id=f"id_{i}",
                                 local_block_ids=[1, 2, 3],
                                 kv_transfer_params={
                                     "remote_block_ids": [4, 5, 6],
                                     "remote_engine_id":
                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                     "remote_host": "localhost",
                                     "remote_port": 1234,
                                 })
        connector.bind_connector_metadata(metadata)

        timeout = 2.5 * total_reqs
        cnt_finished_reqs = 0
        start = time.perf_counter()
        while time.perf_counter() - start < timeout:
            dummy_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            _before_load = time.perf_counter()
            connector.start_load_kv(dummy_ctx)
            _after_load = time.perf_counter()
            assert _after_load - _before_load < 0.1, "start_load_kv took " \
                f"{_after_load - _before_load} seconds"
            time.sleep(0.5)  # backoff for the async handshake to complete.
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                cnt_finished_reqs += len(done_recving)
                if cnt_finished_reqs == total_reqs:
                    return
        raise TimeoutError("Took too long to complete async handshake.")
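The tests above share one pattern for exercising the now non-blocking API: time the call itself against a small budget, then poll `get_finished` under a deadline with a short backoff. A generic helper along those lines (illustrative only; `call` and `poll` are stand-ins for `connector.start_load_kv` and `connector.get_finished`) might look like:

# Illustrative helper, not part of the test file.
import time
from typing import Callable


def assert_nonblocking_then_done(call: Callable[[], None],
                                 poll: Callable[[], bool],
                                 call_budget_s: float = 0.1,
                                 deadline_s: float = 2.5,
                                 backoff_s: float = 0.5) -> None:
    """Assert `call` returns within call_budget_s and `poll` reports
    completion before deadline_s, mirroring the loops in the tests above."""
    start = time.perf_counter()
    while time.perf_counter() - start < deadline_s:
        before = time.perf_counter()
        call()  # must return almost immediately
        elapsed = time.perf_counter() - before
        assert elapsed < call_budget_s, f"call blocked for {elapsed:.3f}s"
        time.sleep(backoff_s)  # give the background handshake time to finish
        if poll():
            return
    raise TimeoutError("async work did not complete before the deadline")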
@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import math
import queue
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

@@ -23,6 +25,7 @@ from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
    get_tp_group)
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
@@ -31,7 +34,6 @@ from vllm.v1.request import RequestStatus

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

@@ -71,7 +73,7 @@ class ReqMeta:
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_engine_id: str
    remote_engine_id: EngineId


class NixlConnectorMetadata(KVConnectorMetadata):
@@ -81,7 +83,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):

    def add_new_req(
        self,
        request_id: str,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
    ):
@@ -102,7 +104,7 @@ class NixlConnector(KVConnectorBase_V1):
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler : Optional[NixlConnectorScheduler] = \
            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                NixlConnectorScheduler(vllm_config, self.engine_id)
            self.connector_worker: Optional[NixlConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
@@ -186,7 +188,7 @@ class NixlConnectorScheduler:
        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        self.side_channel_port = (
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
            vllm_config.parallel_config.data_parallel_rank_local *
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)
        logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -343,7 +345,7 @@ class NixlConnectorWorker:
        # Each TP rank listens/queries on the base_port + tp_rank.
        self.side_channel_port: int = (
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
            vllm_config.parallel_config.data_parallel_rank_local *
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)

        # Metadata.
@@ -386,8 +388,17 @@ class NixlConnectorWorker:
        self._done_sending_count: defaultdict[ReqId,
                                              int] = defaultdict(lambda: 0)

        # Background thread for establishing new connections.
        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: Optional[threading.Thread] = None
        # Background thread for initializing new NIXL handshakes.
        self._handshake_initiation_executor = ThreadPoolExecutor(
            # NIXL is not guaranteed to be thread-safe, limit 1 worker.
            max_workers=1,
            thread_name_prefix="vllm-nixl-handshake-initiator")
        self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
        # Protects _handshake_futures and _remote_agents.
        self._handshake_lock = threading.RLock()

        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
@@ -416,6 +427,12 @@ class NixlConnectorWorker:
        # finish reading before safely freeing the blocks.
        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)

    def __del__(self):
        """Cleanup background threads on destruction."""
        self._handshake_initiation_executor.shutdown(wait=False)
        if self._nixl_handshake_listener_t:
            self._nixl_handshake_listener_t.join(timeout=0)

    @staticmethod
    def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                 ready_event: threading.Event, base_port: int,
@@ -443,7 +460,7 @@ class NixlConnectorWorker:
                        "Connection listener got unexpected message %s", msg)
                sock.send_multipart((identity, b"", encoded_data))

    def _nixl_handshake(self, host: str, port: int):
    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
        """Do a NIXL handshake with a remote instance."""

        start_time = time.perf_counter()
@@ -452,7 +469,7 @@ class NixlConnectorWorker:
        # a hack to keep us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.

        def handshake(path: str, rank: int) -> NixlAgentMetadata:
        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
            # Send query for the request.
            with zmq_ctx(zmq.REQ, path) as sock:
                sock.send(GET_META_MSG)
@@ -462,19 +479,20 @@ class NixlConnectorWorker:
                got_metadata_time = time.perf_counter()

                # Register Remote agent.
                self.add_remote_agent(metadata, rank)
                remote_agent_name = self.add_remote_agent(metadata, rank)
                setup_agent_time = time.perf_counter()

                logger.debug("NIXL handshake: get metadata took: %s",
                             got_metadata_time - start_time)
                logger.debug("NIXL handshake: add agent took: %s",
                             setup_agent_time - got_metadata_time)
                return metadata
                return metadata, remote_agent_name

        # Handshake with remote agent-rank0 first to get the tp_size of remote
        path = make_zmq_path("tcp", host, port)
        logger.debug("Querying master rank metadata on path: %s", path)
        metadata = handshake(path, 0)
        rank_to_agent_name: dict[int, str] = {}
        metadata, rank_to_agent_name[0] = handshake(path, 0)

        # Handshake only with the other TP remote the current local rank will
        # pull from. With homogeneous TP it happens to be the same rank_i.
@@ -484,7 +502,10 @@ class NixlConnectorWorker:
            path = make_zmq_path("tcp", host, port + p_remote_rank)
            logger.debug("Querying metadata on path: %s at remote rank %s",
                         path, p_remote_rank)
            _ = handshake(path, p_remote_rank)
            _, rank_to_agent_name[p_remote_rank] = handshake(
                path, p_remote_rank)

        return rank_to_agent_name

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl."""
@@ -621,11 +642,11 @@ class NixlConnectorWorker:
            daemon=True,
            name="nixl_handshake_listener")
        self._nixl_handshake_listener_t.start()
        ready_event.wait()
        ready_event.wait()  # Wait for listener ZMQ socket to be ready.

    def add_remote_agent(self,
                         nixl_agent_meta: NixlAgentMetadata,
                         remote_tp_rank: int = 0):
                         remote_tp_rank: int = 0) -> str:
        """
        Add the remote NIXL agent and prepare the descriptors for reading cache
        blocks from remote.
@@ -666,8 +687,8 @@ class NixlConnectorWorker:
        """  # noqa: E501
        engine_id = nixl_agent_meta.engine_id
        # TODO re-evaluate refreshing for scaling/recovery
        if remote_tp_rank in self._remote_agents.get(engine_id, ()):
            return
        if remote_tp_rank in self._remote_agents.get(engine_id, {}):
            return self._remote_agents[engine_id][remote_tp_rank]

        if engine_id in self._tp_size:
            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
@@ -677,9 +698,8 @@ class NixlConnectorWorker:
        # layout and close outputs.
        assert nixl_agent_meta.attn_backend_name == self.backend_name

        self._remote_agents[engine_id][
            remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
                nixl_agent_meta.agent_metadata)
        remote_agent_name = self.nixl_wrapper.add_remote_agent(
            nixl_agent_meta.agent_metadata)

        # Number of D TP workers reading from a single P TP worker. This is
        # 1 when P and D `--tensor-parallel-size` match.
@@ -708,8 +728,9 @@ class NixlConnectorWorker:
            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
        )

        assert self.block_size == remote_block_size, "Remote P worker with " \
            "different block size is not supported"
        assert self.block_size == remote_block_size, (
            "Remote P worker with different block size is not supported "
            f"{self.block_size=} {remote_block_size=}")

        # Create dst descs and xfer side handles. TP workers have same #blocks.
        if engine_id in self.dst_num_blocks:
@@ -748,7 +769,9 @@ class NixlConnectorWorker:
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.dst_xfer_side_handles[
            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                self._remote_agents[engine_id][remote_tp_rank], descs)
                remote_agent_name, descs)

        return remote_agent_name

    def get_finished(self) -> tuple[set[str], set[str]]:
        """
@@ -866,33 +889,68 @@ class NixlConnectorWorker:
        We check for these trnxs to complete in each step().
        """
        for req_id, meta in metadata.requests.items():
            remote_engine_id = meta.remote_engine_id
            logger.debug(
                "start_load_kv for request %s from remote engine %s. "
                "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
                meta.remote_engine_id, len(meta.local_block_ids),
                remote_engine_id, len(meta.local_block_ids),
                len(meta.remote_block_ids))
            self._read_blocks(
                request_id=req_id,
                dst_engine_id=meta.remote_engine_id,
                local_block_ids=meta.local_block_ids,
                remote_block_ids=meta.remote_block_ids,
                remote_host=meta.remote_host,
                remote_port=meta.remote_port,
            )
            if remote_engine_id not in self._remote_agents:
                # Being optimistic to assume engine is usually ready, apply
                # lock only when the optimistic check fails.
                with self._handshake_lock:
                    if remote_engine_id not in self._remote_agents:
                        fut = self._handshake_futures.get(remote_engine_id)
                        if fut is None:
                            fut = self._handshake_initiation_executor.submit(
                                self._nixl_handshake, meta.remote_host,
                                meta.remote_port)
                            self._handshake_futures[remote_engine_id] = fut

                            def done_callback(f: Future[dict[int, str]],
                                              eid=remote_engine_id):
                                with self._handshake_lock:
                                    del self._handshake_futures[eid]
                                    try:
                                        self._remote_agents[eid] = f.result()
                                    except Exception:
                                        logger.exception(
                                            "Handshake with %s failed", eid)

                            fut.add_done_callback(done_callback)

                        # TODO: handle failure state of future in the
                        # callback, we want to fail the request in this case.
                        def request_ready(_f: Future[Any],
                                          entry=(req_id, meta)):
                            self._ready_requests.put(entry)

                        fut.add_done_callback(request_ready)
                        continue
            self._read_blocks_for_req(req_id, meta)

        # Start transfers for requests whose handshakes have now finished.
        while not self._ready_requests.empty():
            self._read_blocks_for_req(*self._ready_requests.get_nowait())

    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
        logger.debug(
            "Remote agent %s available, calling _read_blocks for req %s",
            meta.remote_engine_id, req_id)
        self._read_blocks(
            request_id=req_id,
            dst_engine_id=meta.remote_engine_id,
            local_block_ids=meta.local_block_ids,
            remote_block_ids=meta.remote_block_ids,
        )

    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        remote_host: str,
        remote_port: int,
        dst_engine_id: str,
        request_id: str,
    ):
        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
        if dst_engine_id not in self._remote_agents:
            self._nixl_handshake(remote_host, remote_port)

        # NOTE(rob): having the staging blocks be on the READER side is
        # not going to work well (since we will have to call rearrange tensors).
        # after we detect the txn is complete (which means we cannot make the
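As shown above, `_nixl_handshake` now returns a mapping from remote TP rank to the registered NIXL agent name, and the future's completion callback stores that mapping in `_remote_agents[engine_id]`. A stripped-down sketch of that return shape (purely illustrative, with a fabricated agent name in place of the real ZMQ/NIXL calls):

# Illustrative sketch of the new return type; not the real handshake.
def fake_nixl_handshake(host: str, base_port: int,
                        remote_ranks: list[int]) -> dict[int, str]:
    rank_to_agent_name: dict[int, str] = {}
    for rank in remote_ranks:
        # The real code queries metadata over ZMQ at base_port + rank and then
        # registers it via nixl_wrapper.add_remote_agent(), which returns the
        # agent name; here we just fabricate one.
        rank_to_agent_name[rank] = f"agent-{host}:{base_port + rank}"
    return rank_to_agent_name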