[P/D] Asynchronously do _nixl_handshake (#19836)

Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
lkchen 2025-06-24 12:46:10 -07:00 committed by GitHub
parent 8619e7158c
commit 91f7d9d0b6
2 changed files with 259 additions and 91 deletions
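
In short, the change takes the slow `_nixl_handshake` off the `start_load_kv` hot path: the worker submits the handshake to a single-worker `ThreadPoolExecutor`, keeps the pending `Future` per remote engine id behind a lock, and parks each request on a queue that is drained once its handshake completes. The sketch below condenses that pattern with simplified, hypothetical names (`AsyncHandshakeSketch`, `start_load`, `drain_ready`); it is not the connector's real API, which appears in the diff further down.

# Minimal sketch of the async-handshake pattern (hypothetical names,
# error handling elided).
import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncHandshakeSketch:
    """Defer a slow handshake to a background worker; start transfers
    only once the handshake future for that remote engine resolves."""

    def __init__(self) -> None:
        # Single worker: the underlying transfer library may not be
        # thread-safe, so handshakes are serialized.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures: dict[str, Future] = {}
        self._agents: dict[str, dict[int, str]] = {}
        self._ready: queue.Queue[str] = queue.Queue()
        self._lock = threading.RLock()

    def _handshake(self, engine_id: str) -> dict[int, str]:
        time.sleep(0.5)  # stand-in for the real, slow handshake
        return {0: f"agent-{engine_id}"}

    def start_load(self, req_id: str, engine_id: str) -> None:
        """Non-blocking: read now, or defer until the handshake is done."""
        if engine_id not in self._agents:
            # Optimistic check above; take the lock only on the slow path.
            with self._lock:
                if engine_id not in self._agents:
                    fut = self._futures.get(engine_id)
                    if fut is None:
                        fut = self._executor.submit(self._handshake, engine_id)
                        self._futures[engine_id] = fut

                        def record(f: Future, eid: str = engine_id) -> None:
                            # Runs when the handshake finishes.
                            with self._lock:
                                del self._futures[eid]
                                self._agents[eid] = f.result()

                        fut.add_done_callback(record)
                    # Park the request; a later call drains it.
                    fut.add_done_callback(
                        lambda _f, r=req_id: self._ready.put(r))
                    return
        self._read_blocks(req_id)

    def drain_ready(self) -> None:
        # Called on every step: start transfers whose handshakes completed.
        while not self._ready.empty():
            self._read_blocks(self._ready.get_nowait())

    def _read_blocks(self, req_id: str) -> None:
        print(f"reading remote blocks for {req_id}")


if __name__ == "__main__":
    c = AsyncHandshakeSketch()
    c.start_load("req-0", "engine-A")  # returns immediately
    time.sleep(1.0)                    # let the background handshake finish
    c.drain_ready()                    # the deferred transfer starts now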

View File

@ -7,13 +7,6 @@ from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
try:
from nixl._api import nixl_agent as NixlWrapper
except ImportError:
NixlWrapper = None
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from NixlWrapper because NixlWrapper could be None.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""
AGENT_METADATA = b"fake_agent_metadata"
@ -167,7 +161,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
def _nixl_handshake(self, host: str, port: int):
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@ -177,7 +171,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.add_remote_agent(
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@ -187,40 +181,101 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
block_len=self.block_len,
attn_backend_name=self.backend_name,
))
return {0: remote_agent_name}
@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_multi_xfer_one_engine(
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are multiple
transfer states for the same `request_id`, and `get_finished` should handle
it correctly (wait for all transfers to be done).
"""
vllm_config = create_vllm_config()
class TestNixlHandshake:
request_id = "req_id"
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_multi_xfer_one_engine(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
"""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
for i in range(4):
request_id = "req_id"
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper,
FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
# round of `execute_model` calls.
metadata = NixlConnectorMetadata()
if num_xfers > 0:
num_xfers -= 1
metadata.add_new_req(
request_id=request_id,
local_block_ids=[
num_xfers + 1, num_xfers + 2, num_xfers + 3
],
kv_transfer_params={
"remote_block_ids":
[num_xfers + 4, num_xfers + 5, num_xfers + 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host":
"localhost",
"remote_port":
1234,
})
connector.bind_connector_metadata(metadata)
# Mimic maybe_setup_kv_connector in gpu_model_runner.
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
# Mimic get_finished_kv_transfers in gpu_model_runner.
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
assert request_id in done_recving
break
connector.clear_connector_metadata()
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_async_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[i + 1, i + 2, i + 3],
metadata.add_new_req(request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [i + 4, i + 5, i + 6],
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
@ -228,19 +283,74 @@ def test_multi_xfer_one_engine(
})
connector.bind_connector_metadata(metadata)
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
timeout = 2.5
start = time.perf_counter()
while time.perf_counter() - start < timeout:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
return
raise TimeoutError("Took too long to complete async handshake.")
while True:
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
assert request_id in done_recving
break
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_concurrent_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test that multiple start_load_kv calls should occur concurrently."""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
metadata.add_new_req(request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
})
connector.bind_connector_metadata(metadata)
timeout = 2.5 * total_reqs
cnt_finished_reqs = 0
start = time.perf_counter()
while time.perf_counter() - start < timeout:
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
cnt_finished_reqs += len(done_recving)
if cnt_finished_reqs == total_reqs:
return
raise TimeoutError("Took too long to complete async handshake.")
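
The async tests above (`test_async_load_kv`, `test_concurrent_load_kv`) boil down to the same two checks: each call into the connector returns quickly, and completion is observed by polling under a deadline. Condensed against the hypothetical `AsyncHandshakeSketch` from the sketch near the top of this page (assumed to be in scope here), the pattern looks like:

import time


def test_start_load_is_non_blocking() -> None:
    # Assumes AsyncHandshakeSketch from the earlier sketch is defined/importable.
    c = AsyncHandshakeSketch()
    t0 = time.perf_counter()
    c.start_load("req-0", "engine-A")
    # The handshake must not run on the caller's thread.
    assert time.perf_counter() - t0 < 0.1
    deadline = time.perf_counter() + 2.5
    while time.perf_counter() < deadline:
        time.sleep(0.1)               # back off while the handshake runs
        c.drain_ready()               # picks up the deferred request
        if "engine-A" in c._agents:   # handshake result was recorded
            return
    raise TimeoutError("async handshake did not complete in time")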

View File

@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import math
import queue
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
@ -23,6 +25,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
@ -31,7 +34,6 @@ from vllm.v1.request import RequestStatus
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@ -71,7 +73,7 @@ class ReqMeta:
remote_block_ids: list[int]
remote_host: str
remote_port: int
remote_engine_id: str
remote_engine_id: EngineId
class NixlConnectorMetadata(KVConnectorMetadata):
@ -81,7 +83,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def add_new_req(
self,
request_id: str,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
@ -102,7 +104,7 @@ class NixlConnector(KVConnectorBase_V1):
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler : Optional[NixlConnectorScheduler] = \
self.connector_scheduler: Optional[NixlConnectorScheduler] = \
NixlConnectorScheduler(vllm_config, self.engine_id)
self.connector_worker: Optional[NixlConnectorWorker] = None
elif role == KVConnectorRole.WORKER:
@ -186,7 +188,7 @@ class NixlConnectorScheduler:
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
logger.info("Initializing NIXL Scheduler %s", engine_id)
@ -343,7 +345,7 @@ class NixlConnectorWorker:
# Each TP rank listens/queries on the base_port + tp_rank.
self.side_channel_port: int = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
# Metadata.
@ -386,8 +388,17 @@ class NixlConnectorWorker:
self._done_sending_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
# Background thread for establishing new connections.
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
max_workers=1,
thread_name_prefix="vllm-nixl-handshake-initiator")
self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
# Protects _handshake_futures and _remote_agents.
self._handshake_lock = threading.RLock()
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
@ -416,6 +427,12 @@ class NixlConnectorWorker:
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
def __del__(self):
"""Cleanup background threads on destruction."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_t:
self._nixl_handshake_listener_t.join(timeout=0)
@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, base_port: int,
@ -443,7 +460,7 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(self, host: str, port: int):
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
@ -452,7 +469,7 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def handshake(path: str, rank: int) -> NixlAgentMetadata:
def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
@ -462,19 +479,20 @@ class NixlConnectorWorker:
got_metadata_time = time.perf_counter()
# Register Remote agent.
self.add_remote_agent(metadata, rank)
remote_agent_name = self.add_remote_agent(metadata, rank)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return metadata
return metadata, remote_agent_name
# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
metadata = handshake(path, 0)
rank_to_agent_name: dict[int, str] = {}
metadata, rank_to_agent_name[0] = handshake(path, 0)
# Handshake only with the other TP remote the current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
@ -484,7 +502,10 @@ class NixlConnectorWorker:
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
_ = handshake(path, p_remote_rank)
_, rank_to_agent_name[p_remote_rank] = handshake(
path, p_remote_rank)
return rank_to_agent_name
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@ -621,11 +642,11 @@ class NixlConnectorWorker:
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
ready_event.wait()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0):
remote_tp_rank: int = 0) -> str:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
@ -666,8 +687,8 @@ class NixlConnectorWorker:
""" # noqa: E501
engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, ()):
return
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
return self._remote_agents[engine_id][remote_tp_rank]
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
@ -677,9 +698,8 @@ class NixlConnectorWorker:
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
self._remote_agents[engine_id][
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
@ -708,8 +728,9 @@ class NixlConnectorWorker:
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, "Remote P worker with " \
"different block size is not supported"
assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@ -748,7 +769,9 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id][remote_tp_rank], descs)
remote_agent_name, descs)
return remote_agent_name
def get_finished(self) -> tuple[set[str], set[str]]:
"""
@ -866,33 +889,68 @@ class NixlConnectorWorker:
We check for these txns to complete in each step().
"""
for req_id, meta in metadata.requests.items():
remote_engine_id = meta.remote_engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
meta.remote_engine_id, len(meta.local_block_ids),
remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_host=meta.remote_host,
remote_port=meta.remote_port,
)
if remote_engine_id not in self._remote_agents:
# Being optimistic to assume engine is usually ready, apply
# lock only when the optimistic check fails.
with self._handshake_lock:
if remote_engine_id not in self._remote_agents:
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host,
meta.remote_port)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]],
eid=remote_engine_id):
with self._handshake_lock:
del self._handshake_futures[eid]
try:
self._remote_agents[eid] = f.result()
except Exception:
logger.exception(
"Handshake with %s failed", eid)
fut.add_done_callback(done_callback)
# TODO: handle failure state of future in the
# callback, we want to fail the request in this case.
def request_ready(_f: Future[Any],
entry=(req_id, meta)):
self._ready_requests.put(entry)
fut.add_done_callback(request_ready)
continue
self._read_blocks_for_req(req_id, meta)
# Start transfers for requests whose handshakes have now finished.
while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait())
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote_engine_id, req_id)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
)
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_host: str,
remote_port: int,
dst_engine_id: str,
request_id: str,
):
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
if dst_engine_id not in self._remote_agents:
self._nixl_handshake(remote_host, remote_port)
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the