[P/D] [NixlConnector] kv load recovery integration (#26171)
Signed-off-by: Will Eaton <weaton@redhat.com>
This commit is contained in:
parent 0d21b9b51e
commit 53c9a7cee2
@@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
from typing import Optional

{fake_nixl_source}
@@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
    # After abort, the worker should not keep tracking it as "in-batch"
    assert req.request_id not in connector.connector_worker._reqs_to_process
#### Model Runner end ####


class FailingNixlWrapper(FakeNixlWrapper):
    """Mock NixlWrapper that fails on specific operations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fail_handshake = False
        self.fail_transfer_setup = False
        self.fail_send_notif = False

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        if self.fail_handshake:
            from zmq.error import Again

            raise Again("Simulated timeout failure")
        return super().add_remote_agent(agent_metadata)

    def make_prepped_xfer(
        self,
        xfer_type: str,
        local_xfer_side_handle: int,
        local_block_descs_ids: list[int],
        remote_xfer_side_handle: int,
        remote_block_descs_ids: list[int],
        notif_msg: bytes | None = None,
    ) -> int:
        if self.fail_transfer_setup:
            # classic RuntimeError to simulate failure
            raise RuntimeError("BAD STATUS")
        return super().make_prepped_xfer(
            xfer_type,
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg,
        )

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        if self.fail_send_notif:
            raise RuntimeError("Simulated send_notif failure")
        return super().send_notif(agent_name, notif_msg)


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
    """Test that handshake failures mark blocks invalid and return via get_finished."""
    vllm_config = create_vllm_config()

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0.1
    )
    connector.connector_worker.nixl_wrapper.fail_handshake = True

    request_id = "test_handshake_fail"
    metadata = NixlConnectorMetadata()
    metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[1, 2, 3],
        kv_transfer_params={
            "remote_block_ids": [4, 5, 6],
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Wait for handshake to fail
    time.sleep(0.3)

    # Check that blocks were marked invalid
    invalid_blocks = connector.get_block_ids_with_load_errors()
    assert invalid_blocks == {1, 2, 3}

    # Check that request appears in get_finished
    _, done_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in done_recving


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
    """Test that transfer setup failures mark blocks invalid
    and return via get_finished."""
    vllm_config = create_vllm_config()

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    connector.connector_worker.nixl_wrapper.fail_transfer_setup = True

    request_id = "test_transfer_fail"
    metadata = NixlConnectorMetadata()
    metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[7, 8, 9],
        kv_transfer_params={
            "remote_block_ids": [10, 11, 12],
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Wait for handshake to complete and process ready_requests
    connector.bind_connector_metadata(NixlConnectorMetadata())
    time.sleep(0.1)
    connector.start_load_kv(dummy_ctx)

    # check that blocks were marked invalid
    invalid_blocks = connector.get_block_ids_with_load_errors()
    assert invalid_blocks == {7, 8, 9}

    # ensure request appears in get_finished
    _, done_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in done_recving
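Side note on the patching pattern used by both tests above: unittest.mock.patch replaces the NixlWrapper name inside the connector module for the duration of the test, so the worker constructs the failing fake instead of the real wrapper. A minimal, self-contained illustration of that mechanism follows; RealWrapper, FailingWrapper, and run_check are hypothetical stand-ins for this sketch, not vLLM code.

from unittest.mock import patch


class RealWrapper:
    def ping(self) -> str:
        return "ok"


class FailingWrapper(RealWrapper):
    def ping(self) -> str:
        raise RuntimeError("simulated failure")


# Patching by dotted path swaps the class wherever that module-level name is
# looked up, analogous to @patch("...nixl_connector.NixlWrapper", FailingNixlWrapper).
@patch(f"{__name__}.RealWrapper", FailingWrapper)
def run_check() -> str:
    try:
        RealWrapper().ping()
        return "no error"
    except RuntimeError as exc:
        return str(exc)


assert run_check() == "simulated failure"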
@@ -68,6 +68,7 @@ except ImportError:
    NixlWrapper = None
    nixlXferTelemetry = None


try:
    from nixl._api import nixl_agent_config
except ImportError:
@@ -234,6 +235,11 @@ class NixlConnector(KVConnectorBase_V1):
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Get block IDs that failed to load via NIXL."""
        assert self.connector_worker is not None
        return self.connector_worker.get_block_ids_with_load_errors()

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        assert self.connector_worker is not None
        return self.connector_worker.get_kv_connector_stats()
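To see how the two connector-level calls above are meant to be consumed together, here is a rough sketch of the per-step drain pattern: blocks that failed to load are reported for rescheduling, and the affected requests still surface through get_finished so they are not stranded. The _StubConnector class below is a hypothetical stand-in for illustration only, not the vLLM scheduler or model-runner code.

class _StubConnector:
    """Hypothetical stand-in mimicking the two calls added above."""

    def __init__(self) -> None:
        self._invalid_block_ids = {7, 8, 9}
        self._failed_recv_reqs = {"req-a"}

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Drain-and-clear: each call returns what accumulated since the last one.
        out, self._invalid_block_ids = self._invalid_block_ids, set()
        return out

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        done_recv, self._failed_recv_reqs = self._failed_recv_reqs, set()
        return set(), done_recv


connector = _StubConnector()
# Each step: failed blocks are collected for rescheduling...
invalid_blocks = connector.get_block_ids_with_load_errors()
# ...and the requests that skipped their transfer still finish "recving".
_, done_recving = connector.get_finished(finished_req_ids=set())
assert invalid_blocks == {7, 8, 9} and "req-a" in done_recving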
@@ -614,6 +620,11 @@ class NixlConnectorWorker:
        # Set of requests that have been part of a batch, regardless of status.
        self._reqs_to_process: set[ReqId] = set()

        # invalid blocks from failed NIXL operations
        self._invalid_block_ids: set[int] = set()
        # requests that skipped transfer (handshake or transfer failures)
        self._failed_recv_reqs: set[ReqId] = set()

        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: threading.Thread | None = None
        # Background thread for initializing new NIXL handshakes.
@@ -713,6 +724,8 @@ class NixlConnectorWorker:

        # Send query for the request.
        with zmq_ctx(zmq.REQ, path) as sock:
            # Set receive timeout to 5 seconds to avoid hanging on dead server
            sock.setsockopt(zmq.RCVTIMEO, 5000)  # milliseconds
            sock.send(GET_META_MSG)
            metadata_bytes = sock.recv()
            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
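The RCVTIMEO change above is what turns a dead or wedged metadata server into a recoverable error instead of an indefinite hang: with a receive timeout set, recv() raises zmq.error.Again, which the handshake future then propagates. A small self-contained pyzmq sketch of that behaviour (the inproc endpoint and message bytes are made up for the example):

import zmq

ctx = zmq.Context()

# A REP peer that accepts the request but never answers, standing in for a
# dead or wedged remote agent.
rep = ctx.socket(zmq.REP)
rep.bind("inproc://meta")

req = ctx.socket(zmq.REQ)
req.setsockopt(zmq.RCVTIMEO, 100)  # milliseconds, like the 5000 ms used above
req.connect("inproc://meta")

req.send(b"GET_META")
rep.recv()  # the request arrives, but we deliberately send no reply
try:
    req.recv()
except zmq.error.Again:
    print("metadata request timed out instead of hanging forever")
finally:
    req.close(linger=0)
    rep.close(linger=0)
    ctx.term()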
@@ -795,10 +808,20 @@ class NixlConnectorWorker:

            fut.add_done_callback(done_callback)

        # TODO: handle failure state of future in the
        # callback, we want to fail the request in this case.
        def request_ready(_f: Future[Any], entry=(req_id, meta)):
            self._ready_requests.put(entry)
        # check handshake success before proceeding with request
        def request_ready(f: Future[Any], entry=(req_id, meta)):
            try:
                # check if handshake succeeded
                f.result()
                self._ready_requests.put(entry)
            except Exception:
                # handshake failed - mark blocks as invalid
                logger.exception(
                    "Handshake failed for request %s, marking blocks as invalid", req_id
                )
                if req_meta := self._recving_metadata.get(req_id):
                    self._invalid_block_ids.update(req_meta.local_block_ids)
                self._failed_recv_reqs.add(req_id)

        fut.add_done_callback(request_ready)
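The key detail in the new request_ready callback is that f.result() re-raises whatever exception the handshake thread hit, which is how the failure is detected without blocking the caller. A standalone concurrent.futures illustration of that pattern (the names here are hypothetical, not vLLM code):

from concurrent.futures import Future, ThreadPoolExecutor

ready, failed = [], []


def _handshake():
    raise TimeoutError("simulated handshake timeout")


def _request_ready(f: Future, entry="req-0"):
    try:
        f.result()            # re-raises the handshake exception, if any
        ready.append(entry)   # success path: hand the request onward
    except Exception:
        failed.append(entry)  # failure path: mark blocks invalid, report later


with ThreadPoolExecutor(max_workers=1) as pool:
    fut = pool.submit(_handshake)
    fut.add_done_callback(_request_ready)

assert failed == ["req-0"] and not ready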
@@ -1205,6 +1228,11 @@ class NixlConnectorWorker:
        """
        done_sending = self._get_new_notifs()
        done_recving = self._pop_done_transfers(self._recving_transfers)

        # add requests that skipped transfer to done_recving
        done_recving.update(self._failed_recv_reqs)
        self._failed_recv_reqs.clear()

        if len(done_sending) > 0 or len(done_recving) > 0:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
@@ -1214,10 +1242,10 @@ class NixlConnectorWorker:
                len(done_recving),
            )

        if self.use_host_buffer:
            for req_id in done_recving:
                meta = self._recving_metadata.pop(req_id)
                assert meta, f"{req_id} not found in recving_metadata list"
        # clean up metadata for completed requests
        for req_id in done_recving:
            meta = self._recving_metadata.pop(req_id, None)
            if self.use_host_buffer and meta:
                self.sync_recved_kv_to_device(req_id, meta)

        # Handle timeout to avoid stranding blocks on remote.
@@ -1296,7 +1324,19 @@ class NixlConnectorWorker:
                    in_progress = True
                    continue
                else:
                    raise RuntimeError("Transfer failed with state %s", xfer_state)
                    # transfer failed - mark blocks as invalid
                    logger.error(
                        "NIXL transfer failed for request %s with state %s. "
                        "Marking blocks as invalid.",
                        req_id,
                        xfer_state,
                    )
                    # mark all blocks for this request as invalid
                    if meta := self._recving_metadata.pop(req_id, None):
                        self._invalid_block_ids.update(meta.local_block_ids)
                    self._recving_metadata.pop(req_id, None)
                    self.nixl_wrapper.release_xfer_handle(handle)
                    self.xfer_stats.record_failed_transfer()
            if not in_progress:
                done_req_ids.add(req_id)
                del transfers[req_id]
@@ -1317,8 +1357,8 @@ class NixlConnectorWorker:
            len(meta.local_block_ids),
            len(meta.remote_block_ids),
        )
        if self.use_host_buffer:
            self._recving_metadata[req_id] = meta
        # always store metadata for failure recovery
        self._recving_metadata[req_id] = meta
        if remote_engine_id not in self._remote_agents:
            # Initiate handshake with remote engine to exchange metadata.
            with self._handshake_lock:
@@ -1394,7 +1434,16 @@ class NixlConnectorWorker:
        if num_local_blocks == 0:
            remote_rank = self.tp_rank // tp_ratio
            agent_name = self._remote_agents[dst_engine_id][remote_rank]
            self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            try:
                self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            except Exception:
                logger.exception(
                    "NIXL send_notif failed for request %s: "
                    "P worker blocks will be freed after timeout. "
                    "This may indicate network issues.",
                    request_id,
                )
                self.xfer_stats.record_failed_notification()
            return

        # Partial prefix cache hit: just read uncomputed blocks.
@@ -1456,20 +1505,35 @@ class NixlConnectorWorker:
        assert len(local_block_descs_ids) == len(remote_block_descs_ids)

        # Prepare transfer with Nixl.
        handle = self.nixl_wrapper.make_prepped_xfer(
            "READ",
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg=notif_id,
        )
        handle = None
        try:
            handle = self.nixl_wrapper.make_prepped_xfer(
                "READ",
                local_xfer_side_handle,
                local_block_descs_ids,
                remote_xfer_side_handle,
                remote_block_descs_ids,
                notif_msg=notif_id,
            )

        # Begin async xfer.
        self.nixl_wrapper.transfer(handle)
            # Begin async xfer.
            self.nixl_wrapper.transfer(handle)

        # Use handle to check completion in future step().
        self._recving_transfers[request_id].append((handle, time.perf_counter()))
            # Use handle to check completion in future step().
            self._recving_transfers[request_id].append((handle, time.perf_counter()))
        except Exception:
            logger.exception(
                "NIXL transfer setup/initiation failed for request %s. "
                "Marking blocks as invalid.",
                request_id,
            )
            # mark all blocks for this request as invalid
            if meta := self._recving_metadata.get(request_id):
                self._invalid_block_ids.update(meta.local_block_ids)
            self.xfer_stats.record_failed_transfer()
            if handle is not None:
                self.nixl_wrapper.release_xfer_handle(handle)
            self._failed_recv_reqs.add(request_id)

    def _get_block_descs_ids(
        self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
@@ -1527,6 +1591,17 @@ class NixlConnectorWorker:
            return self.xfer_stats.clone_and_reset()
        return None

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Return and clear the set of block IDs that failed to load.

        This is called by the scheduler to identify blocks that need
        to be retried after a NIXL transfer failure.
        """
        result = self._invalid_block_ids
        self._invalid_block_ids = set()
        return result

    def shutdown(self):
        """Shutdown the connector worker."""
        self._handshake_initiation_executor.shutdown(wait=False)
@@ -1586,6 +1661,8 @@ class NixlKVConnectorStats(KVConnectorStats):
            "post_duration": [],
            "bytes_transferred": [],
            "num_descriptors": [],
            "num_failed_transfers": [],
            "num_failed_notifications": [],
        }

    def record_transfer(self, res: nixlXferTelemetry):
@@ -1595,6 +1672,14 @@ class NixlKVConnectorStats(KVConnectorStats):
        self.data["bytes_transferred"].append(res.totalBytes)
        self.data["num_descriptors"].append(res.descCount)

    def record_failed_transfer(self):
        """Record a failed NIXL transfer operation."""
        self.data["num_failed_transfers"].append(1.0)

    def record_failed_notification(self):
        """Record a failed NIXL notification (send_notif)."""
        self.data["num_failed_notifications"].append(1.0)

    def clone_and_reset(self) -> "NixlKVConnectorStats":
        old = copy.copy(self)
        self.reset()
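Since the failure counters above are recorded as one 1.0 entry per event, any downstream aggregation reduces to summing the per-event lists. A tiny hypothetical sketch of that reduction (not the KVConnectorStats API):

# Example data shaped like the stats dict above; values are per-event markers.
data = {"num_failed_transfers": [1.0, 1.0], "num_failed_notifications": [1.0]}
summary = {key: int(sum(values)) for key, values in data.items()}
assert summary == {"num_failed_transfers": 2, "num_failed_notifications": 1}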
@@ -1487,7 +1487,7 @@ class Scheduler(SchedulerInterface):
            total_tokens_to_reschedule += num_tokens_to_reschedule

        # Mark requests with async KV load failures; they will be rescheduled
        # once loading completes
        # once loading completes.
        self.failed_recving_kv_req_ids |= async_affected_req_ids

        # --- Handle sync KV loads (running requests) ---