mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 06:17:51 +08:00
[P/D] [NixlConnector] kv load recovery integration (#26171)
Signed-off-by: Will Eaton <weaton@redhat.com>
This commit is contained in:
parent
0d21b9b51e
commit
53c9a7cee2
@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
|
|||||||
# Copy of FakeNixlWrapper implementation for Ray workers
|
# Copy of FakeNixlWrapper implementation for Ray workers
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
{fake_nixl_source}
|
{fake_nixl_source}
|
||||||
|
|
||||||
@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
|
|||||||
# After abort, the worker should not keep tracking it as "in-batch"
|
# After abort, the worker should not keep tracking it as "in-batch"
|
||||||
assert req.request_id not in connector.connector_worker._reqs_to_process
|
assert req.request_id not in connector.connector_worker._reqs_to_process
|
||||||
#### Model Runner end ####
|
#### Model Runner end ####
|
||||||
|
|
||||||
|
|
||||||
|
class FailingNixlWrapper(FakeNixlWrapper):
    """FakeNixlWrapper variant whose operations can be forced to fail.

    Every ``fail_*`` flag starts out False. A test flips one to True to make
    the matching wrapper call raise instead of delegating to the fake.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying fake, then disable every failure switch."""
        super().__init__(*args, **kwargs)
        self.fail_handshake = self.fail_transfer_setup = self.fail_send_notif = False

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        """Delegate to the fake unless ``fail_handshake`` is set.

        Raises:
            zmq.error.Again: when ``fail_handshake`` is True.
        """
        if not self.fail_handshake:
            return super().add_remote_agent(agent_metadata)

        # Imported only on the failure path, as in the original.
        from zmq.error import Again

        raise Again("Simulated timeout failure")

    def make_prepped_xfer(
        self,
        xfer_type: str,
        local_xfer_side_handle: int,
        local_block_descs_ids: list[int],
        remote_xfer_side_handle: int,
        remote_block_descs_ids: list[int],
        notif_msg: bytes | None = None,
    ) -> int:
        """Delegate to the fake unless ``fail_transfer_setup`` is set.

        Raises:
            RuntimeError: when ``fail_transfer_setup`` is True.
        """
        if self.fail_transfer_setup:
            # A plain RuntimeError stands in for a failed transfer setup.
            raise RuntimeError("BAD STATUS")

        forwarded = (
            xfer_type,
            local_xfer_side_handle,
            local_block_descs_ids,
            remote_xfer_side_handle,
            remote_block_descs_ids,
            notif_msg,
        )
        return super().make_prepped_xfer(*forwarded)

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        """Delegate to the fake unless ``fail_send_notif`` is set.

        Raises:
            RuntimeError: when ``fail_send_notif`` is True.
        """
        if not self.fail_send_notif:
            return super().send_notif(agent_name, notif_msg)
        raise RuntimeError("Simulated send_notif failure")
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
    """A failed handshake invalidates the request's local blocks and the
    request is still reported by get_finished()."""
    req_id = "test_handshake_fail"
    local_blocks = [1, 2, 3]

    config = create_vllm_config()
    connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        config, connector.engine_id, hand_shake_latency=0.1
    )
    # Make the wrapper raise when the remote agent is added.
    connector.connector_worker.nixl_wrapper.fail_handshake = True

    remote_params = {
        "remote_block_ids": [4, 5, 6],
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    pending = NixlConnectorMetadata()
    pending.add_new_req(
        request_id=req_id,
        local_block_ids=local_blocks,
        kv_transfer_params=remote_params,
    )
    connector.bind_connector_metadata(pending)

    forward_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(forward_ctx)

    # Give the handshake time to fail.
    time.sleep(0.3)

    # Every local block of the request must be flagged as a load error.
    assert connector.get_block_ids_with_load_errors() == set(local_blocks)

    # The failed request must still be reported as done receiving.
    _, finished_recving = connector.get_finished(finished_req_ids=set())
    assert req_id in finished_recving
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
    """A failure while preparing the transfer invalidates the request's
    local blocks and the request is still reported by get_finished()."""
    req_id = "test_transfer_fail"
    local_blocks = [7, 8, 9]

    config = create_vllm_config()
    connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        config, connector.engine_id, hand_shake_latency=0
    )
    # Make the wrapper raise from make_prepped_xfer.
    connector.connector_worker.nixl_wrapper.fail_transfer_setup = True

    remote_params = {
        "remote_block_ids": [10, 11, 12],
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    pending = NixlConnectorMetadata()
    pending.add_new_req(
        request_id=req_id,
        local_block_ids=local_blocks,
        kv_transfer_params=remote_params,
    )
    connector.bind_connector_metadata(pending)

    forward_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(forward_ctx)

    # Let the handshake complete, then run a second pass with empty
    # metadata so the ready request gets processed.
    connector.bind_connector_metadata(NixlConnectorMetadata())
    time.sleep(0.1)
    connector.start_load_kv(forward_ctx)

    # Every local block of the request must be flagged as a load error.
    assert connector.get_block_ids_with_load_errors() == set(local_blocks)

    # The failed request must still be reported as done receiving.
    _, finished_recving = connector.get_finished(finished_req_ids=set())
    assert req_id in finished_recving
|
||||||
|
|||||||
@ -68,6 +68,7 @@ except ImportError:
|
|||||||
NixlWrapper = None
|
NixlWrapper = None
|
||||||
nixlXferTelemetry = None
|
nixlXferTelemetry = None
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from nixl._api import nixl_agent_config
|
from nixl._api import nixl_agent_config
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -234,6 +235,11 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
return self.connector_worker.get_finished()
|
return self.connector_worker.get_finished()
|
||||||
|
|
||||||
|
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the block IDs that failed to load via NIXL.

    Delegates to the connector worker, which must already be set.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.get_block_ids_with_load_errors()
|
||||||
|
|
||||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
return self.connector_worker.get_kv_connector_stats()
|
return self.connector_worker.get_kv_connector_stats()
|
||||||
@ -614,6 +620,11 @@ class NixlConnectorWorker:
|
|||||||
# Set of requests that have been part of a batch, regardless of status.
|
# Set of requests that have been part of a batch, regardless of status.
|
||||||
self._reqs_to_process: set[ReqId] = set()
|
self._reqs_to_process: set[ReqId] = set()
|
||||||
|
|
||||||
|
# invalid blocks from failed NIXL operations
|
||||||
|
self._invalid_block_ids: set[int] = set()
|
||||||
|
# requests that skipped transfer (handshake or transfer failures)
|
||||||
|
self._failed_recv_reqs: set[ReqId] = set()
|
||||||
|
|
||||||
# Background thread for handling new handshake requests.
|
# Background thread for handling new handshake requests.
|
||||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||||
# Background thread for initializing new NIXL handshakes.
|
# Background thread for initializing new NIXL handshakes.
|
||||||
@ -713,6 +724,8 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
|
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||||
sock.send(GET_META_MSG)
|
sock.send(GET_META_MSG)
|
||||||
metadata_bytes = sock.recv()
|
metadata_bytes = sock.recv()
|
||||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
@ -795,10 +808,20 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
fut.add_done_callback(done_callback)
|
fut.add_done_callback(done_callback)
|
||||||
|
|
||||||
# TODO: handle failure state of future in the
|
# check handshake success before proceeding with request
|
||||||
# callback, we want to fail the request in this case.
|
def request_ready(f: Future[Any], entry=(req_id, meta)):
|
||||||
def request_ready(_f: Future[Any], entry=(req_id, meta)):
|
try:
|
||||||
self._ready_requests.put(entry)
|
# check if handshake succeeded
|
||||||
|
f.result()
|
||||||
|
self._ready_requests.put(entry)
|
||||||
|
except Exception:
|
||||||
|
# handshake failed - mark blocks as invalid
|
||||||
|
logger.exception(
|
||||||
|
"Handshake failed for request %s, marking blocks as invalid", req_id
|
||||||
|
)
|
||||||
|
if req_meta := self._recving_metadata.get(req_id):
|
||||||
|
self._invalid_block_ids.update(req_meta.local_block_ids)
|
||||||
|
self._failed_recv_reqs.add(req_id)
|
||||||
|
|
||||||
fut.add_done_callback(request_ready)
|
fut.add_done_callback(request_ready)
|
||||||
|
|
||||||
@ -1205,6 +1228,11 @@ class NixlConnectorWorker:
|
|||||||
"""
|
"""
|
||||||
done_sending = self._get_new_notifs()
|
done_sending = self._get_new_notifs()
|
||||||
done_recving = self._pop_done_transfers(self._recving_transfers)
|
done_recving = self._pop_done_transfers(self._recving_transfers)
|
||||||
|
|
||||||
|
# add requests that skipped transfer to done_recving
|
||||||
|
done_recving.update(self._failed_recv_reqs)
|
||||||
|
self._failed_recv_reqs.clear()
|
||||||
|
|
||||||
if len(done_sending) > 0 or len(done_recving) > 0:
|
if len(done_sending) > 0 or len(done_recving) > 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Rank %s, get_finished: %s requests done sending "
|
"Rank %s, get_finished: %s requests done sending "
|
||||||
@ -1214,10 +1242,10 @@ class NixlConnectorWorker:
|
|||||||
len(done_recving),
|
len(done_recving),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_host_buffer:
|
# clean up metadata for completed requests
|
||||||
for req_id in done_recving:
|
for req_id in done_recving:
|
||||||
meta = self._recving_metadata.pop(req_id)
|
meta = self._recving_metadata.pop(req_id, None)
|
||||||
assert meta, f"{req_id} not found in recving_metadata list"
|
if self.use_host_buffer and meta:
|
||||||
self.sync_recved_kv_to_device(req_id, meta)
|
self.sync_recved_kv_to_device(req_id, meta)
|
||||||
|
|
||||||
# Handle timeout to avoid stranding blocks on remote.
|
# Handle timeout to avoid stranding blocks on remote.
|
||||||
@ -1296,7 +1324,19 @@ class NixlConnectorWorker:
|
|||||||
in_progress = True
|
in_progress = True
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Transfer failed with state %s", xfer_state)
|
# transfer failed - mark blocks as invalid
|
||||||
|
logger.error(
|
||||||
|
"NIXL transfer failed for request %s with state %s. "
|
||||||
|
"Marking blocks as invalid.",
|
||||||
|
req_id,
|
||||||
|
xfer_state,
|
||||||
|
)
|
||||||
|
# mark all blocks for this request as invalid
|
||||||
|
if meta := self._recving_metadata.pop(req_id, None):
|
||||||
|
self._invalid_block_ids.update(meta.local_block_ids)
|
||||||
|
self._recving_metadata.pop(req_id, None)
|
||||||
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
|
self.xfer_stats.record_failed_transfer()
|
||||||
if not in_progress:
|
if not in_progress:
|
||||||
done_req_ids.add(req_id)
|
done_req_ids.add(req_id)
|
||||||
del transfers[req_id]
|
del transfers[req_id]
|
||||||
@ -1317,8 +1357,8 @@ class NixlConnectorWorker:
|
|||||||
len(meta.local_block_ids),
|
len(meta.local_block_ids),
|
||||||
len(meta.remote_block_ids),
|
len(meta.remote_block_ids),
|
||||||
)
|
)
|
||||||
if self.use_host_buffer:
|
# always store metadata for failure recovery
|
||||||
self._recving_metadata[req_id] = meta
|
self._recving_metadata[req_id] = meta
|
||||||
if remote_engine_id not in self._remote_agents:
|
if remote_engine_id not in self._remote_agents:
|
||||||
# Initiate handshake with remote engine to exchange metadata.
|
# Initiate handshake with remote engine to exchange metadata.
|
||||||
with self._handshake_lock:
|
with self._handshake_lock:
|
||||||
@ -1394,7 +1434,16 @@ class NixlConnectorWorker:
|
|||||||
if num_local_blocks == 0:
|
if num_local_blocks == 0:
|
||||||
remote_rank = self.tp_rank // tp_ratio
|
remote_rank = self.tp_rank // tp_ratio
|
||||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
try:
|
||||||
|
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"NIXL send_notif failed for request %s: "
|
||||||
|
"P worker blocks will be freed after timeout. "
|
||||||
|
"This may indicate network issues.",
|
||||||
|
request_id,
|
||||||
|
)
|
||||||
|
self.xfer_stats.record_failed_notification()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Partial prefix cache hit: just read uncomputed blocks.
|
# Partial prefix cache hit: just read uncomputed blocks.
|
||||||
@ -1456,20 +1505,35 @@ class NixlConnectorWorker:
|
|||||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||||
|
|
||||||
# Prepare transfer with Nixl.
|
# Prepare transfer with Nixl.
|
||||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
handle = None
|
||||||
"READ",
|
try:
|
||||||
local_xfer_side_handle,
|
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||||
local_block_descs_ids,
|
"READ",
|
||||||
remote_xfer_side_handle,
|
local_xfer_side_handle,
|
||||||
remote_block_descs_ids,
|
local_block_descs_ids,
|
||||||
notif_msg=notif_id,
|
remote_xfer_side_handle,
|
||||||
)
|
remote_block_descs_ids,
|
||||||
|
notif_msg=notif_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Begin async xfer.
|
# Begin async xfer.
|
||||||
self.nixl_wrapper.transfer(handle)
|
self.nixl_wrapper.transfer(handle)
|
||||||
|
|
||||||
# Use handle to check completion in future step().
|
# Use handle to check completion in future step().
|
||||||
self._recving_transfers[request_id].append((handle, time.perf_counter()))
|
self._recving_transfers[request_id].append((handle, time.perf_counter()))
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"NIXL transfer setup/initiation failed for request %s. "
|
||||||
|
"Marking blocks as invalid.",
|
||||||
|
request_id,
|
||||||
|
)
|
||||||
|
# mark all blocks for this request as invalid
|
||||||
|
if meta := self._recving_metadata.get(request_id):
|
||||||
|
self._invalid_block_ids.update(meta.local_block_ids)
|
||||||
|
self.xfer_stats.record_failed_transfer()
|
||||||
|
if handle is not None:
|
||||||
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
|
self._failed_recv_reqs.add(request_id)
|
||||||
|
|
||||||
def _get_block_descs_ids(
|
def _get_block_descs_ids(
|
||||||
self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
|
self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
|
||||||
@ -1527,6 +1591,17 @@ class NixlConnectorWorker:
|
|||||||
return self.xfer_stats.clone_and_reset()
|
return self.xfer_stats.clone_and_reset()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Hand back the block IDs that failed to load and reset the tracker.

    The accumulated set is returned as-is and replaced by a fresh empty
    set, so each failed block is reported only once. Intended for
    identifying blocks that need to be retried after a NIXL transfer
    failure.
    """
    failed, self._invalid_block_ids = self._invalid_block_ids, set()
    return failed
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Shutdown the connector worker."""
|
"""Shutdown the connector worker."""
|
||||||
self._handshake_initiation_executor.shutdown(wait=False)
|
self._handshake_initiation_executor.shutdown(wait=False)
|
||||||
@ -1586,6 +1661,8 @@ class NixlKVConnectorStats(KVConnectorStats):
|
|||||||
"post_duration": [],
|
"post_duration": [],
|
||||||
"bytes_transferred": [],
|
"bytes_transferred": [],
|
||||||
"num_descriptors": [],
|
"num_descriptors": [],
|
||||||
|
"num_failed_transfers": [],
|
||||||
|
"num_failed_notifications": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
def record_transfer(self, res: nixlXferTelemetry):
|
def record_transfer(self, res: nixlXferTelemetry):
|
||||||
@ -1595,6 +1672,14 @@ class NixlKVConnectorStats(KVConnectorStats):
|
|||||||
self.data["bytes_transferred"].append(res.totalBytes)
|
self.data["bytes_transferred"].append(res.totalBytes)
|
||||||
self.data["num_descriptors"].append(res.descCount)
|
self.data["num_descriptors"].append(res.descCount)
|
||||||
|
|
||||||
|
def record_failed_transfer(self):
    """Append one entry to the failed-transfer series for a failed NIXL transfer."""
    failed_transfers = self.data["num_failed_transfers"]
    failed_transfers.append(1.0)
|
||||||
|
|
||||||
|
def record_failed_notification(self):
    """Append one entry to the failed-notification series for a failed send_notif."""
    failed_notifications = self.data["num_failed_notifications"]
    failed_notifications.append(1.0)
|
||||||
|
|
||||||
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
||||||
old = copy.copy(self)
|
old = copy.copy(self)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|||||||
@ -1487,7 +1487,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
total_tokens_to_reschedule += num_tokens_to_reschedule
|
total_tokens_to_reschedule += num_tokens_to_reschedule
|
||||||
|
|
||||||
# Mark requests with async KV load failures; they will be rescheduled
|
# Mark requests with async KV load failures; they will be rescheduled
|
||||||
# once loading completes
|
# once loading completes.
|
||||||
self.failed_recving_kv_req_ids |= async_affected_req_ids
|
self.failed_recving_kv_req_ids |= async_affected_req_ids
|
||||||
|
|
||||||
# --- Handle sync KV loads (running requests) ---
|
# --- Handle sync KV loads (running requests) ---
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user