Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 11:55:00 +08:00)
[P/D] Asynchronously do _nixl_handshake (#19836)

Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

Parent: 8619e7158c
Commit: 91f7d9d0b6
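In brief, this change moves the NIXL handshake off the critical path: start_load_kv no longer calls _nixl_handshake inline. Instead, the worker submits the handshake to a single-worker ThreadPoolExecutor, tracks one Future per remote engine id, and queues each affected request until its handshake future completes, at which point the deferred block read runs. The first group of hunks below updates the NIXL connector unit tests; the second group updates the connector module itself. As an orientation aid only, here is a minimal, self-contained Python sketch of that pattern. It is not vLLM code; names such as AsyncHandshakeDemo are purely illustrative.

import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncHandshakeDemo:

    def __init__(self):
        # NIXL is not guaranteed to be thread-safe, so a single worker is used.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures: dict[str, Future] = {}
        self._agents: dict[str, dict[int, str]] = {}
        self._ready: queue.Queue[str] = queue.Queue()
        self._lock = threading.RLock()

    def _handshake(self, engine_id: str) -> dict[int, str]:
        time.sleep(0.2)  # stands in for the ZMQ metadata exchange
        return {0: f"agent-{engine_id}"}

    def load(self, req_id: str, engine_id: str) -> None:
        if engine_id not in self._agents:       # optimistic check
            with self._lock:                    # re-check under the lock
                if engine_id not in self._agents:
                    fut = self._futures.get(engine_id)
                    if fut is None:
                        fut = self._executor.submit(self._handshake, engine_id)
                        self._futures[engine_id] = fut

                        def done(f: Future, eid: str = engine_id) -> None:
                            # Record the agents and retire the future.
                            with self._lock:
                                del self._futures[eid]
                                self._agents[eid] = f.result()

                        fut.add_done_callback(done)
                    # Defer the read until the handshake finishes.
                    fut.add_done_callback(
                        lambda _f, r=req_id: self._ready.put(r))
                    return
        self._read_blocks(req_id)

    def drain_ready(self) -> None:
        # Start reads for requests whose handshakes have now finished.
        while not self._ready.empty():
            self._read_blocks(self._ready.get_nowait())

    def _read_blocks(self, req_id: str) -> None:
        print("reading blocks for", req_id)


demo = AsyncHandshakeDemo()
demo.load("req-0", "remote-engine")  # returns immediately
time.sleep(0.5)
demo.drain_ready()                   # the deferred read runs here

The single worker mirrors the diff's own comment that NIXL is not guaranteed to be thread-safe, and the re-check under the lock is the same optimistic pattern the new start_load_kv uses below.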
Unit tests for the NIXL connector:

@@ -7,13 +7,6 @@ from collections import defaultdict
 from typing import Optional
 from unittest.mock import patch

-import pytest
-
-try:
-    from nixl._api import nixl_agent as NixlWrapper
-except ImportError:
-    NixlWrapper = None
-
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.

-    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
     """

     AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency

-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks

-        self.add_remote_agent(
+        remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,40 +181,101 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
             ))
+        return {0: remote_agent_name}


-@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper)
-def test_multi_xfer_one_engine(
-        # dist_init is a fixture that initializes the distributed environment.
-        dist_init):
-    """Test case where multiple xfers are initiated to the same engine.
-
-    This test triggers the connector to load remote KV for the same
-    `request_id`. The transfer is not done immediately due to
-    `set_cycles_before_xfer_done`, so there is a state where there are multiple
-    transfer states for the same `request_id`, and `get_finished` should handle
-    it correctly (wait for all transfers to be done).
-    """
-    vllm_config = create_vllm_config()
-
-    request_id = "req_id"
-
-    # Test worker role in decode server.
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
-                                                         connector.engine_id,
-                                                         hand_shake_latency=0)
-    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
-    for i in range(4):
+class TestNixlHandshake:
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_multi_xfer_one_engine(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test case where multiple xfers are initiated to the same engine.
+
+        This test triggers the connector to load remote KV for the same
+        `request_id`. The transfer is not done immediately due to
+        `set_cycles_before_xfer_done`, so there is a state where there are
+        multiple transfer states for the same `request_id`, and `get_finished`
+        should handle it correctly (wait for all transfers to be done).
+        """
+        vllm_config = create_vllm_config()
+
+        request_id = "req_id"
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0)
+        assert isinstance(connector.connector_worker.nixl_wrapper,
+                          FakeNixlWrapper)
+        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        num_xfers = 4
+        while True:
+            # For the same request_id, initiate multiple xfers across different
+            # round of `execute_model` calls.
+            metadata = NixlConnectorMetadata()
+            if num_xfers > 0:
+                num_xfers -= 1
+                metadata.add_new_req(
+                    request_id=request_id,
+                    local_block_ids=[
+                        num_xfers + 1, num_xfers + 2, num_xfers + 3
+                    ],
+                    kv_transfer_params={
+                        "remote_block_ids":
+                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
+                        "remote_engine_id":
+                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                        "remote_host":
+                        "localhost",
+                        "remote_port":
+                        1234,
+                    })
+            connector.bind_connector_metadata(metadata)
+
+            # Mimic maybe_setup_kv_connector in gpu_model_runner.
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+
+            # Mimic get_finished_kv_transfers in gpu_model_runner.
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                assert request_id in done_recving
+                break
+
+            connector.clear_connector_metadata()
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_async_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that NixlConnector's start_load_kv should be non-blocking."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
         metadata = NixlConnectorMetadata()
-        metadata.add_new_req(request_id=request_id,
-                             local_block_ids=[i + 1, i + 2, i + 3],
+        metadata.add_new_req(request_id="id",
+                             local_block_ids=[1, 2, 3],
                              kv_transfer_params={
-                                 "remote_block_ids": [i + 4, i + 5, i + 6],
+                                 "remote_block_ids": [4, 5, 6],
                                  "remote_engine_id":
                                  FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                  "remote_host": "localhost",
@@ -228,19 +283,74 @@ def test_multi_xfer_one_engine(
                              })
         connector.bind_connector_metadata(metadata)

-        dummy_ctx = ForwardContext(
-            no_compile_layers={},
-            attn_metadata={},
-            virtual_engine=0,
-        )
-        _before_load = time.perf_counter()
-        connector.start_load_kv(dummy_ctx)
-        _after_load = time.perf_counter()
-        assert _after_load - _before_load < 0.1, "start_load_kv took " \
-            f"{_after_load - _before_load} seconds"
-
-    while True:
-        _, done_recving = connector.get_finished(finished_req_ids=set())
-        if len(done_recving) > 0:
-            assert request_id in done_recving
-            break
+        timeout = 2.5
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                return
+        raise TimeoutError("Took too long to complete async handshake.")
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_concurrent_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that multiple start_load_kv calls should occur concurrently."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        total_reqs = 5
+        for i in range(total_reqs):
+            metadata.add_new_req(request_id=f"id_{i}",
+                                 local_block_ids=[1, 2, 3],
+                                 kv_transfer_params={
+                                     "remote_block_ids": [4, 5, 6],
+                                     "remote_engine_id":
+                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                     "remote_host": "localhost",
+                                     "remote_port": 1234,
+                                 })
+        connector.bind_connector_metadata(metadata)
+
+        timeout = 2.5 * total_reqs
+        cnt_finished_reqs = 0
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                cnt_finished_reqs += len(done_recving)
+                if cnt_finished_reqs == total_reqs:
+                    return
+        raise TimeoutError("Took too long to complete async handshake.")
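The tests above assert that start_load_kv stays non-blocking by timing the call itself and then polling get_finished on subsequent "model runner" steps until the receive completes. The remaining hunks apply to the connector module itself (vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector). A generic, illustrative helper capturing the timing half of that test pattern is shown below; it is not part of the test suite, and assert_returns_quickly is a made-up name.

import time


def assert_returns_quickly(fn, *args, limit: float = 0.1, **kwargs):
    """Assert that fn(*args, **kwargs) returns in under `limit` seconds."""
    before = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - before
    assert elapsed < limit, f"call took {elapsed:.3f}s, expected < {limit}s"
    return result


# Stand-in callable for illustration; in the tests above the timed call is
# connector.start_load_kv(dummy_ctx).
assert_returns_quickly(time.sleep, 0.01)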
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:

@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import math
+import queue
 import threading
 import time
 import uuid
 from collections import defaultdict
 from collections.abc import Iterator
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -23,6 +25,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
 from vllm.distributed.utils import divide
+from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
@@ -31,7 +34,6 @@ from vllm.v1.request import RequestStatus

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

@@ -71,7 +73,7 @@ class ReqMeta:
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
-    remote_engine_id: str
+    remote_engine_id: EngineId


 class NixlConnectorMetadata(KVConnectorMetadata):
@@ -81,7 +83,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):

     def add_new_req(
         self,
-        request_id: str,
+        request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
@@ -102,7 +104,7 @@ class NixlConnector(KVConnectorBase_V1):
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler : Optional[NixlConnectorScheduler] = \
+            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                 NixlConnectorScheduler(vllm_config, self.engine_id)
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
@@ -186,7 +188,7 @@ class NixlConnectorScheduler:
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -343,7 +345,7 @@ class NixlConnectorWorker:
         # Each TP rank listens/queries on the base_port + tp_rank.
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
-            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)

         # Metadata.
@@ -386,8 +388,17 @@ class NixlConnectorWorker:
         self._done_sending_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)

-        # Background thread for establishing new connections.
+        # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
+        # Background thread for initializing new NIXL handshakes.
+        self._handshake_initiation_executor = ThreadPoolExecutor(
+            # NIXL is not guaranteed to be thread-safe, limit 1 worker.
+            max_workers=1,
+            thread_name_prefix="vllm-nixl-handshake-initiator")
+        self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
+        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
+        # Protects _handshake_futures and _remote_agents.
+        self._handshake_lock = threading.RLock()

         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
@@ -416,6 +427,12 @@ class NixlConnectorWorker:
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)

+    def __del__(self):
+        """Cleanup background threads on destruction."""
+        self._handshake_initiation_executor.shutdown(wait=False)
+        if self._nixl_handshake_listener_t:
+            self._nixl_handshake_listener_t.join(timeout=0)
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, base_port: int,
@@ -443,7 +460,7 @@ class NixlConnectorWorker:
                     "Connection listener got unexpected message %s", msg)
             sock.send_multipart((identity, b"", encoded_data))

-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
@@ -452,7 +469,7 @@ class NixlConnectorWorker:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.

-        def handshake(path: str, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -462,19 +479,20 @@ class NixlConnectorWorker:
             got_metadata_time = time.perf_counter()

             # Register Remote agent.
-            self.add_remote_agent(metadata, rank)
+            remote_agent_name = self.add_remote_agent(metadata, rank)
             setup_agent_time = time.perf_counter()

             logger.debug("NIXL handshake: get metadata took: %s",
                          got_metadata_time - start_time)
             logger.debug("NIXL handshake: add agent took: %s",
                          setup_agent_time - got_metadata_time)
-            return metadata
+            return metadata, remote_agent_name

         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = make_zmq_path("tcp", host, port)
         logger.debug("Querying master rank metadata on path: %s", path)
-        metadata = handshake(path, 0)
+        rank_to_agent_name: dict[int, str] = {}
+        metadata, rank_to_agent_name[0] = handshake(path, 0)

         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
@@ -484,7 +502,10 @@ class NixlConnectorWorker:
             path = make_zmq_path("tcp", host, port + p_remote_rank)
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            _ = handshake(path, p_remote_rank)
+            _, rank_to_agent_name[p_remote_rank] = handshake(
+                path, p_remote_rank)
+
+        return rank_to_agent_name

     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -621,11 +642,11 @@ class NixlConnectorWorker:
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
-        ready_event.wait()
+        ready_event.wait()  # Wait for listener ZMQ socket to be ready.

     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
-                         remote_tp_rank: int = 0):
+                         remote_tp_rank: int = 0) -> str:
         """
         Add the remote NIXL agent and prepare the descriptors for reading cache
         blocks from remote.
@@ -666,8 +687,8 @@ class NixlConnectorWorker:
         """ # noqa: E501
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
-        if remote_tp_rank in self._remote_agents.get(engine_id, ()):
-            return
+        if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            return self._remote_agents[engine_id][remote_tp_rank]

         if engine_id in self._tp_size:
             assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
@@ -677,9 +698,8 @@ class NixlConnectorWorker:
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name

-        self._remote_agents[engine_id][
-            remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
-                nixl_agent_meta.agent_metadata)
+        remote_agent_name = self.nixl_wrapper.add_remote_agent(
+            nixl_agent_meta.agent_metadata)

         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
@@ -708,8 +728,9 @@ class NixlConnectorWorker:
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

-        assert self.block_size == remote_block_size, "Remote P worker with " \
-            "different block size is not supported"
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported "
+            f"{self.block_size=} {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -748,7 +769,9 @@ class NixlConnectorWorker:
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
         self.dst_xfer_side_handles[
             engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id][remote_tp_rank], descs)
+                remote_agent_name, descs)
+
+        return remote_agent_name

     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -866,33 +889,68 @@ class NixlConnectorWorker:
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.requests.items():
+            remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
-                meta.remote_engine_id, len(meta.local_block_ids),
+                remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))
-            self._read_blocks(
-                request_id=req_id,
-                dst_engine_id=meta.remote_engine_id,
-                local_block_ids=meta.local_block_ids,
-                remote_block_ids=meta.remote_block_ids,
-                remote_host=meta.remote_host,
-                remote_port=meta.remote_port,
-            )
+            if remote_engine_id not in self._remote_agents:
+                # Being optimistic to assume engine is usually ready, apply
+                # lock only when the optimistic check fails.
+                with self._handshake_lock:
+                    if remote_engine_id not in self._remote_agents:
+                        fut = self._handshake_futures.get(remote_engine_id)
+                        if fut is None:
+                            fut = self._handshake_initiation_executor.submit(
+                                self._nixl_handshake, meta.remote_host,
+                                meta.remote_port)
+                            self._handshake_futures[remote_engine_id] = fut
+
+                            def done_callback(f: Future[dict[int, str]],
+                                              eid=remote_engine_id):
+                                with self._handshake_lock:
+                                    del self._handshake_futures[eid]
+                                    try:
+                                        self._remote_agents[eid] = f.result()
+                                    except Exception:
+                                        logger.exception(
+                                            "Handshake with %s failed", eid)
+
+                            fut.add_done_callback(done_callback)
+
+                        # TODO: handle failure state of future in the
+                        # callback, we want to fail the request in this case.
+                        def request_ready(_f: Future[Any],
+                                          entry=(req_id, meta)):
+                            self._ready_requests.put(entry)
+
+                        fut.add_done_callback(request_ready)
+                        continue
+            self._read_blocks_for_req(req_id, meta)
+
+        # Start transfers for requests whose handshakes have now finished.
+        while not self._ready_requests.empty():
+            self._read_blocks_for_req(*self._ready_requests.get_nowait())
+
+    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
+        logger.debug(
+            "Remote agent %s available, calling _read_blocks for req %s",
+            meta.remote_engine_id, req_id)
+        self._read_blocks(
+            request_id=req_id,
+            dst_engine_id=meta.remote_engine_id,
+            local_block_ids=meta.local_block_ids,
+            remote_block_ids=meta.remote_block_ids,
+        )

     def _read_blocks(
         self,
         local_block_ids: list[int],
         remote_block_ids: list[int],
-        remote_host: str,
-        remote_port: int,
         dst_engine_id: str,
         request_id: str,
     ):
-        # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
-        if dst_engine_id not in self._remote_agents:
-            self._nixl_handshake(remote_host, remote_port)
-
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the
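For reference, a hedged sketch of the data shape the reworked handshake produces: _nixl_handshake now returns a mapping from remote TP rank to the NIXL agent name registered by add_remote_agent, and the done callback caches that mapping in _remote_agents keyed by the remote engine id, where prep_xfer_dlist later looks up the agent name. The literal values below are made up for illustration only.

# Illustrative values only; real agent names come from
# nixl_wrapper.add_remote_agent(...).
rank_to_agent_name: dict[int, str] = {
    0: "remote-agent-rank0",
    1: "remote-agent-rank1",
}

# What done_callback stores, keyed by remote engine id (made-up id here).
remote_agents: dict[str, dict[int, str]] = {
    "prefill-engine-0": rank_to_agent_name,
}

# The agent name for the rank being read from is later passed to
# nixl_wrapper.prep_xfer_dlist(...) when preparing transfer descriptors.
agent_name = remote_agents["prefill-engine-0"][0]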