[P/D] [NixlConnector] kv load recovery integration (#26171)

Signed-off-by: Will Eaton <weaton@redhat.com>
Will Eaton 2025-10-13 11:48:04 -04:00 committed by GitHub
parent 0d21b9b51e
commit 53c9a7cee2
3 changed files with 252 additions and 26 deletions

View File

@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
from typing import Optional
{fake_nixl_source}
@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
    # After abort, the worker should not keep tracking it as "in-batch"
    assert req.request_id not in connector.connector_worker._reqs_to_process


#### Model Runner end ####


class FailingNixlWrapper(FakeNixlWrapper):
    """Mock NixlWrapper that fails on specific operations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fail_handshake = False
        self.fail_transfer_setup = False
        self.fail_send_notif = False

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        if self.fail_handshake:
            from zmq.error import Again

            raise Again("Simulated timeout failure")
        return super().add_remote_agent(agent_metadata)

    def make_prepped_xfer(
        self,
        xfer_type: str,
        local_xfer_side_handle: int,
        local_block_descs_ids: list[int],
        remote_xfer_side_handle: int,
        remote_block_descs_ids: list[int],
        notif_msg: bytes | None = None,
    ) -> int:
        if self.fail_transfer_setup:
            # classic RuntimeError to simulate failure
            raise RuntimeError("BAD STATUS")
        return super().make_prepped_xfer(
            xfer_type,
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg,
        )

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        if self.fail_send_notif:
            raise RuntimeError("Simulated send_notif failure")
        return super().send_notif(agent_name, notif_msg)

@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
    """Test that handshake failures mark blocks invalid and return via get_finished."""
    vllm_config = create_vllm_config()
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0.1
    )
    connector.connector_worker.nixl_wrapper.fail_handshake = True

    request_id = "test_handshake_fail"
    metadata = NixlConnectorMetadata()
    metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[1, 2, 3],
        kv_transfer_params={
            "remote_block_ids": [4, 5, 6],
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Wait for handshake to fail
    time.sleep(0.3)

    # Check that blocks were marked invalid
    invalid_blocks = connector.get_block_ids_with_load_errors()
    assert invalid_blocks == {1, 2, 3}

    # Check that request appears in get_finished
    _, done_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in done_recving

@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
    """Test that transfer setup failures mark blocks invalid
    and return via get_finished."""
    vllm_config = create_vllm_config()
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    connector.connector_worker.nixl_wrapper.fail_transfer_setup = True

    request_id = "test_transfer_fail"
    metadata = NixlConnectorMetadata()
    metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[7, 8, 9],
        kv_transfer_params={
            "remote_block_ids": [10, 11, 12],
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Wait for handshake to complete and process ready_requests
    connector.bind_connector_metadata(NixlConnectorMetadata())
    time.sleep(0.1)
    connector.start_load_kv(dummy_ctx)

    # check that blocks were marked invalid
    invalid_blocks = connector.get_block_ids_with_load_errors()
    assert invalid_blocks == {7, 8, 9}

    # ensure request appears in get_finished
    _, done_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in done_recving
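
Both tests follow the same failure-injection recipe: patch the wrapper with a failing double, drive the connector, then assert the failure surfaces through get_block_ids_with_load_errors() and get_finished(). A stripped-down, self-contained sketch of that recipe (toy classes, not the vLLM ones):

from unittest.mock import patch


class Wrapper:
    def load(self, block_ids: list[int]) -> None:
        pass  # pretend the KV load succeeds


class FailingWrapper(Wrapper):
    def load(self, block_ids: list[int]) -> None:
        raise RuntimeError("simulated load failure")


class Worker:
    """Toy stand-in for the connector worker."""

    wrapper_cls = Wrapper

    def __init__(self):
        self.invalid_block_ids: set[int] = set()

    def start_load(self, block_ids: list[int]) -> None:
        try:
            self.wrapper_cls().load(block_ids)
        except Exception:
            # Failure recovery: remember which blocks never loaded.
            self.invalid_block_ids.update(block_ids)


@patch.object(Worker, "wrapper_cls", FailingWrapper)
def test_failure_marks_blocks_invalid():
    worker = Worker()
    worker.start_load([1, 2, 3])
    assert worker.invalid_block_ids == {1, 2, 3}


test_failure_marks_blocks_invalid()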

View File

@ -68,6 +68,7 @@ except ImportError:
    NixlWrapper = None
    nixlXferTelemetry = None

try:
    from nixl._api import nixl_agent_config
except ImportError:
@ -234,6 +235,11 @@ class NixlConnector(KVConnectorBase_V1):
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Get block IDs that failed to load via NIXL."""
        assert self.connector_worker is not None
        return self.connector_worker.get_block_ids_with_load_errors()

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        assert self.connector_worker is not None
        return self.connector_worker.get_kv_connector_stats()
@ -614,6 +620,11 @@ class NixlConnectorWorker:
        # Set of requests that have been part of a batch, regardless of status.
        self._reqs_to_process: set[ReqId] = set()
        # invalid blocks from failed NIXL operations
        self._invalid_block_ids: set[int] = set()
        # requests that skipped transfer (handshake or transfer failures)
        self._failed_recv_reqs: set[ReqId] = set()
        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: threading.Thread | None = None
        # Background thread for initializing new NIXL handshakes.
@ -713,6 +724,8 @@ class NixlConnectorWorker:
        # Send query for the request.
        with zmq_ctx(zmq.REQ, path) as sock:
            # Set receive timeout to 5 seconds to avoid hanging on dead server
            sock.setsockopt(zmq.RCVTIMEO, 5000)  # milliseconds
            sock.send(GET_META_MSG)
            metadata_bytes = sock.recv()
            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
@ -795,10 +808,20 @@ class NixlConnectorWorker:
        fut.add_done_callback(done_callback)

        # TODO: handle failure state of future in the
        # callback, we want to fail the request in this case.
        def request_ready(_f: Future[Any], entry=(req_id, meta)):
            self._ready_requests.put(entry)

        # check handshake success before proceeding with request
        def request_ready(f: Future[Any], entry=(req_id, meta)):
            try:
                # check if handshake succeeded
                f.result()
                self._ready_requests.put(entry)
            except Exception:
                # handshake failed - mark blocks as invalid
                logger.exception(
                    "Handshake failed for request %s, marking blocks as invalid", req_id
                )
                if req_meta := self._recving_metadata.get(req_id):
                    self._invalid_block_ids.update(req_meta.local_block_ids)
                self._failed_recv_reqs.add(req_id)

        fut.add_done_callback(request_ready)
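
For context, the new request_ready callback relies on Future.result() re-raising the handshake exception inside the done-callback, so the failure is recorded on the worker instead of being lost in the executor thread. A minimal, self-contained sketch of that pattern (the names here are illustrative, not the connector's real attributes):

import queue
from concurrent.futures import Future, ThreadPoolExecutor

ready_requests: queue.Queue = queue.Queue()
invalid_block_ids: set[int] = set()
failed_recv_reqs: set[str] = set()


def handshake(ok: bool) -> str:
    if not ok:
        raise TimeoutError("simulated handshake timeout")
    return "remote-agent"


def make_request_ready(req_id: str, block_ids: list[int]):
    def request_ready(f: Future) -> None:
        try:
            f.result()  # re-raises the handshake exception, if any
            ready_requests.put(req_id)
        except Exception:
            # Handshake failed: mark the request's blocks as invalid.
            invalid_block_ids.update(block_ids)
            failed_recv_reqs.add(req_id)

    return request_ready


with ThreadPoolExecutor(max_workers=1) as pool:
    fut = pool.submit(handshake, False)
    fut.add_done_callback(make_request_ready("req-0", [1, 2, 3]))

print(failed_recv_reqs, invalid_block_ids)  # {'req-0'} {1, 2, 3}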
@ -1205,6 +1228,11 @@ class NixlConnectorWorker:
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
# add requests that skipped transfer to done_recving
done_recving.update(self._failed_recv_reqs)
self._failed_recv_reqs.clear()
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
@ -1214,10 +1242,10 @@ class NixlConnectorWorker:
                len(done_recving),
            )

        if self.use_host_buffer:
            for req_id in done_recving:
                meta = self._recving_metadata.pop(req_id)
                assert meta, f"{req_id} not found in recving_metadata list"
        # clean up metadata for completed requests
        for req_id in done_recving:
            meta = self._recving_metadata.pop(req_id, None)
            if self.use_host_buffer and meta:
                self.sync_recved_kv_to_device(req_id, meta)

        # Handle timeout to avoid stranding blocks on remote.
@ -1296,7 +1324,19 @@ class NixlConnectorWorker:
                    in_progress = True
                    continue
                else:
                    raise RuntimeError("Transfer failed with state %s", xfer_state)
                    # transfer failed - mark blocks as invalid
                    logger.error(
                        "NIXL transfer failed for request %s with state %s. "
                        "Marking blocks as invalid.",
                        req_id,
                        xfer_state,
                    )
                    # mark all blocks for this request as invalid
                    if meta := self._recving_metadata.pop(req_id, None):
                        self._invalid_block_ids.update(meta.local_block_ids)
                    self._recving_metadata.pop(req_id, None)
                    self.nixl_wrapper.release_xfer_handle(handle)
                    self.xfer_stats.record_failed_transfer()
            if not in_progress:
                done_req_ids.add(req_id)
                del transfers[req_id]
@ -1317,8 +1357,8 @@ class NixlConnectorWorker:
                len(meta.local_block_ids),
                len(meta.remote_block_ids),
            )
            if self.use_host_buffer:
                self._recving_metadata[req_id] = meta
            # always store metadata for failure recovery
            self._recving_metadata[req_id] = meta
            if remote_engine_id not in self._remote_agents:
                # Initiate handshake with remote engine to exchange metadata.
                with self._handshake_lock:
@ -1394,7 +1434,16 @@ class NixlConnectorWorker:
        if num_local_blocks == 0:
            remote_rank = self.tp_rank // tp_ratio
            agent_name = self._remote_agents[dst_engine_id][remote_rank]
            self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            try:
                self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            except Exception:
                logger.exception(
                    "NIXL send_notif failed for request %s: "
                    "P worker blocks will be freed after timeout. "
                    "This may indicate network issues.",
                    request_id,
                )
                self.xfer_stats.record_failed_notification()
            return

        # Partial prefix cache hit: just read uncomputed blocks.
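
The hunk above makes the full-prefix-hit notification best-effort: a failed send_notif is logged and counted rather than propagated, and the prefill worker is left to free its blocks when its own timeout fires. A small self-contained sketch of that pattern (made-up names, not the actual NIXL API):

import logging

logger = logging.getLogger("notify")
num_failed_notifications = 0  # stands in for xfer_stats.record_failed_notification()


def notify_blocks_consumed(send_notif, agent_name: str, notif_msg: bytes) -> None:
    """Best-effort notification: failures are logged and counted, never raised."""
    global num_failed_notifications
    try:
        send_notif(agent_name, notif_msg)
    except Exception:
        logger.exception(
            "notification to %s failed; remote blocks will be freed by its timeout",
            agent_name,
        )
        num_failed_notifications += 1


def broken_send(agent_name: str, notif_msg: bytes) -> None:
    raise ConnectionError("link down")


notify_blocks_consumed(broken_send, "prefill-rank-0", b"req-0")
print(num_failed_notifications)  # 1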
@ -1456,20 +1505,35 @@ class NixlConnectorWorker:
        assert len(local_block_descs_ids) == len(remote_block_descs_ids)

        # Prepare transfer with Nixl.
        handle = self.nixl_wrapper.make_prepped_xfer(
            "READ",
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg=notif_id,
        )
        handle = None
        try:
            handle = self.nixl_wrapper.make_prepped_xfer(
                "READ",
                local_xfer_side_handle,
                local_block_descs_ids,
                remote_xfer_side_handle,
                remote_block_descs_ids,
                notif_msg=notif_id,
            )

        # Begin async xfer.
        self.nixl_wrapper.transfer(handle)
            # Begin async xfer.
            self.nixl_wrapper.transfer(handle)

        # Use handle to check completion in future step().
        self._recving_transfers[request_id].append((handle, time.perf_counter()))
            # Use handle to check completion in future step().
            self._recving_transfers[request_id].append((handle, time.perf_counter()))
        except Exception:
            logger.exception(
                "NIXL transfer setup/initiation failed for request %s. "
                "Marking blocks as invalid.",
                request_id,
            )
            # mark all blocks for this request as invalid
            if meta := self._recving_metadata.get(request_id):
                self._invalid_block_ids.update(meta.local_block_ids)
            self.xfer_stats.record_failed_transfer()
            if handle is not None:
                self.nixl_wrapper.release_xfer_handle(handle)
            self._failed_recv_reqs.add(request_id)

    def _get_block_descs_ids(
        self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
@ -1527,6 +1591,17 @@ class NixlConnectorWorker:
            return self.xfer_stats.clone_and_reset()
        return None

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Return and clear the set of block IDs that failed to load.

        This is called by the scheduler to identify blocks that need
        to be retried after a NIXL transfer failure.
        """
        result = self._invalid_block_ids
        self._invalid_block_ids = set()
        return result
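
Note the read-and-clear contract above: each failed block is reported exactly once. A self-contained sketch of a caller polling this method together with get_finished() on every step (toy connector, not the real classes or signatures):

class ToyConnector:
    """Stands in for NixlConnector; the real method signatures differ."""

    def __init__(self):
        self._invalid_block_ids: set[int] = set()
        self._done_recving: set[str] = set()

    def fail(self, req_id: str, block_ids: set[int]) -> None:
        self._invalid_block_ids |= block_ids
        self._done_recving.add(req_id)

    def get_block_ids_with_load_errors(self) -> set[int]:
        result, self._invalid_block_ids = self._invalid_block_ids, set()
        return result

    def get_finished(self) -> set[str]:
        done, self._done_recving = self._done_recving, set()
        return done


conn = ToyConnector()
conn.fail("req-0", {1, 2, 3})

# One scheduler step: the failed load surfaces through both calls...
assert conn.get_block_ids_with_load_errors() == {1, 2, 3}
assert conn.get_finished() == {"req-0"}
# ...and the state is cleared, so the next step reports nothing.
assert conn.get_block_ids_with_load_errors() == set()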

    def shutdown(self):
        """Shutdown the connector worker."""
        self._handshake_initiation_executor.shutdown(wait=False)
@ -1586,6 +1661,8 @@ class NixlKVConnectorStats(KVConnectorStats):
"post_duration": [],
"bytes_transferred": [],
"num_descriptors": [],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
def record_transfer(self, res: nixlXferTelemetry):
@ -1595,6 +1672,14 @@ class NixlKVConnectorStats(KVConnectorStats):
self.data["bytes_transferred"].append(res.totalBytes)
self.data["num_descriptors"].append(res.descCount)
def record_failed_transfer(self):
"""Record a failed NIXL transfer operation."""
self.data["num_failed_transfers"].append(1.0)
def record_failed_notification(self):
"""Record a failed NIXL notification (send_notif)."""
self.data["num_failed_notifications"].append(1.0)
def clone_and_reset(self) -> "NixlKVConnectorStats":
old = copy.copy(self)
self.reset()

View File

@ -1487,7 +1487,7 @@ class Scheduler(SchedulerInterface):
        total_tokens_to_reschedule += num_tokens_to_reschedule

        # Mark requests with async KV load failures; they will be rescheduled
        # once loading completes
        # once loading completes.
        self.failed_recving_kv_req_ids |= async_affected_req_ids

        # --- Handle sync KV loads (running requests) ---
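
The surrounding scheduler logic is not part of this diff; as a toy illustration only, one way the connector's invalid block IDs could be mapped back to affected requests before reaching this line (hypothetical helper, not vLLM's actual code):

failed_recving_kv_req_ids: set[str] = set()


def mark_failed_async_loads(
    invalid_block_ids: set[int],
    req_block_ids: dict[str, list[int]],
) -> set[str]:
    """Any request whose allocated blocks intersect the invalid set is affected."""
    affected = {
        req_id
        for req_id, blocks in req_block_ids.items()
        if invalid_block_ids & set(blocks)
    }
    failed_recving_kv_req_ids.update(affected)
    return affected


print(mark_failed_async_loads({2, 9}, {"a": [1, 2, 3], "b": [7, 8]}))  # {'a'}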