[Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-09-27 21:17:35 +02:00 committed by GitHub
parent c216119d64
commit da63274d9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
def add_new_req(
self,
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
def get_num_new_matched_tokens(
self, request: "Request",
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
)
for req_id, (req, block_ids) in self._reqs_need_save.items():
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
)
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_need_send = {}
return meta
@ -546,6 +555,8 @@ class NixlConnectorWorker:
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status.
self._reqs_to_process: set[ReqId] = set()
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@ -1082,6 +1093,7 @@ class NixlConnectorWorker:
"Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
self._reqs_to_process.remove(req_id)
del self._reqs_to_send[req_id]
done_sending.add(req_id)
@ -1097,7 +1109,8 @@ class NixlConnectorWorker:
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
if req_id not in self._reqs_to_send:
if (req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process):
logger.error(
"Potentially invalid KV blocks for "
"unrecognized request %s were retrieved by "
@ -1110,7 +1123,8 @@ class NixlConnectorWorker:
tp_ratio):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
del self._reqs_to_send[req_id]
self._reqs_to_process.remove(req_id)
self._reqs_to_send.pop(req_id, None)
return notified_req_ids
def _pop_done_transfers(
@ -1171,8 +1185,19 @@ class NixlConnectorWorker:
while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait())
# Keep around the requests that have been part of a batch. This is
# needed because async scheduling pushes the misalignment between the
# moment in which requests expiration is set (P side) and the moment in
# which blocks are read from D. As P can now more easily lag behind D
# while processing the next batch, we make sure to only set an
# expiration for requests that have not been read from D yet.
for req_id in metadata.reqs_in_batch:
self._reqs_to_process.add(req_id)
# Add to requests that are waiting to be read and track expiration.
self._reqs_to_send.update(metadata.reqs_to_send)
for req_id, expiration_time in metadata.reqs_to_send.items():
if req_id in self._reqs_to_process:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(