[bugfix] avoid NIXL_ERR_REMOTE_DISCONNECT in nixl_connector when Prefill dies (#28120)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
This commit is contained in:
Mathis Felardos 2025-11-27 16:32:38 +01:00 committed by GitHub
parent 66d3d5422c
commit cd007a53b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1832,35 +1832,55 @@ class NixlConnectorWorker:
done_req_ids: set[str] = set() done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()): for req_id, handles in list(transfers.items()):
in_progress = False in_progress = False
for handle, _xfer_stime in handles: for handle, xfer_start_time in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle) try:
if xfer_state == "DONE": xfer_state = self.nixl_wrapper.check_xfer_state(handle)
# Get telemetry from NIXL if xfer_state == "DONE":
res = self.nixl_wrapper.get_xfer_telemetry(handle) # Get telemetry from NIXL
self.xfer_stats.record_transfer(res) res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.nixl_wrapper.release_xfer_handle(handle) self.xfer_stats.record_transfer(res)
elif xfer_state == "PROC": self.nixl_wrapper.release_xfer_handle(handle)
in_progress = True elif xfer_state == "PROC":
continue in_progress = True
else: continue
# transfer failed - mark blocks as invalid else:
logger.error( logger.error(
"NIXL transfer failed for request %s with state %s. " "NIXL transfer failed for request %s with state "
"%s. Marking blocks as invalid.",
req_id,
xfer_state,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
"Marking blocks as invalid.", "Marking blocks as invalid.",
req_id, req_id,
xfer_state,
) )
# mark all (logical)blocks for this request as invalid self._handle_failed_transfer(req_id, handle)
if meta := self._recving_metadata.pop(req_id, None): in_progress = False
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
if not in_progress: if not in_progress:
done_req_ids.add(req_id) done_req_ids.add(req_id)
del transfers[req_id] del transfers[req_id]
return done_req_ids return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int) -> None:
    """
    Handle a failed transfer by marking all (logical) blocks as invalid and
    recording the failure.

    Args:
        req_id: The request ID.
        handle: The transfer handle.
    """
    # pop() both fetches and removes the request's receive metadata, so a
    # single call is sufficient; a request without a metadata entry simply
    # contributes no invalid blocks.
    if meta := self._recving_metadata.pop(req_id, None):
        self._invalid_block_ids.update(meta.local_block_ids)
    # Release the transfer handle, then count the failure in the stats.
    self.nixl_wrapper.release_xfer_handle(handle)
    self.xfer_stats.record_failed_transfer()
def start_load_kv(self, metadata: NixlConnectorMetadata): def start_load_kv(self, metadata: NixlConnectorMetadata):
""" """
Start loading by triggering non-blocking nixl_xfer. Start loading by triggering non-blocking nixl_xfer.