[bugfix] avoid NIXL_ERR_REMOTE_DISCONNECT in nixl_connector when Prefill dies (#28120)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
This commit is contained in:
Mathis Felardos 2025-11-27 16:32:38 +01:00 committed by GitHub
parent 66d3d5422c
commit cd007a53b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1832,35 +1832,55 @@ class NixlConnectorWorker:
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# Get telemetry from NIXL
res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
continue
else:
# transfer failed - mark blocks as invalid
logger.error(
"NIXL transfer failed for request %s with state %s. "
for handle, xfer_start_time in handles:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# Get telemetry from NIXL
res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
continue
else:
logger.error(
"NIXL transfer failed for request %s with state "
"%s. Marking blocks as invalid.",
req_id,
xfer_state,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
"Marking blocks as invalid.",
req_id,
xfer_state,
)
# mark all (logical)blocks for this request as invalid
if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
self._handle_failed_transfer(req_id, handle)
in_progress = False
if not in_progress:
done_req_ids.add(req_id)
del transfers[req_id]
return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int):
"""
Handle a failed transfer by marking all (logical) blocks as invalid and
recording the failure.
Args:
req_id: The request ID.
handle: The transfer handle.
"""
if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
def start_load_kv(self, metadata: NixlConnectorMetadata):
"""
Start loading by triggering non-blocking nixl_xfer.