[Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent c216119d64
commit da63274d9f
@@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
         self.reqs_to_send: dict[ReqId, float] = {}
+        self.reqs_in_batch: set[ReqId] = set()

     def add_new_req(
         self,
@@ -278,6 +279,7 @@ class NixlConnectorScheduler:
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
+        self._reqs_in_batch: set[ReqId] = set()

     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -324,6 +326,9 @@ class NixlConnectorScheduler:

         if not params:
             return
+
+        if params.get("do_remote_decode"):
+            self._reqs_in_batch.add(request.request_id)
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
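For readers skimming the hunk above: the new branch only records requests flagged for remote decode in the scheduler's per-step set. A minimal, stand-alone sketch of that bookkeeping, using a hypothetical _SchedulerSketch class and request ids (the real logic lives in NixlConnectorScheduler and reads the request's kv_transfer_params):

# Hypothetical stand-in for the scheduler-side bookkeeping added above;
# not vLLM's actual class.
class _SchedulerSketch:
    def __init__(self) -> None:
        self._reqs_in_batch: set[str] = set()

    def update_after_alloc(self, request_id: str, params: dict | None) -> None:
        if not params:
            return
        # Mirrors the added branch: only remote-decode requests are tracked
        # for the current scheduling step.
        if params.get("do_remote_decode"):
            self._reqs_in_batch.add(request_id)

sched = _SchedulerSketch()
sched.update_after_alloc("req-1", {"do_remote_decode": True})
sched.update_after_alloc("req-2", {"do_remote_prefill": True})
assert sched._reqs_in_batch == {"req-1"}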
@@ -373,6 +378,8 @@ class NixlConnectorScheduler:
                 request_id=req_id,
                 local_block_ids=block_ids,
                 kv_transfer_params=req.kv_transfer_params,
+                load_remote_cache=True,
+                save_to_host=False,
             )

         for req_id, (req, block_ids) in self._reqs_need_save.items():
@@ -386,10 +393,12 @@ class NixlConnectorScheduler:
             )

         meta.reqs_to_send = self._reqs_need_send
+        meta.reqs_in_batch = self._reqs_in_batch

         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
+        self._reqs_in_batch = set()
         self._reqs_need_send = {}

         return meta
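The hunk above follows a build-and-reset pattern: the per-step state accumulated by the scheduler is handed to the workers via the connector metadata, then the scheduler's own containers are reset. Note that the containers assigned to the metadata by reference (reqs_to_send, reqs_in_batch) are rebound to fresh objects rather than cleared, so the metadata keeps its data while the copied-per-item dicts are simply .clear()-ed. A minimal sketch with simplified stand-ins (not the real NixlConnectorMetadata / NixlConnectorScheduler classes):

# Simplified stand-ins for the build-and-reset pattern shown above.
class _Meta:
    def __init__(self) -> None:
        self.reqs_to_send: dict[str, float] = {}
        self.reqs_in_batch: set[str] = set()

class _Sched:
    def __init__(self) -> None:
        self._reqs_need_send: dict[str, float] = {}
        self._reqs_in_batch: set[str] = set()

    def build_meta(self) -> _Meta:
        meta = _Meta()
        # Hand the accumulated per-step state to the worker side...
        meta.reqs_to_send = self._reqs_need_send
        meta.reqs_in_batch = self._reqs_in_batch
        # ...then rebind fresh containers so the next step starts empty
        # without mutating the objects the metadata now owns.
        self._reqs_need_send = {}
        self._reqs_in_batch = set()
        return meta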
@@ -546,6 +555,8 @@ class NixlConnectorWorker:
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
+        # Set of requests that have been part of a batch, regardless of status.
+        self._reqs_to_process: set[ReqId] = set()

         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -1082,6 +1093,7 @@ class NixlConnectorWorker:
                 "Releasing expired KV blocks for request %s which were "
                 "retrieved by %d decode worker(s) within %d seconds.", req_id,
                 count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+            self._reqs_to_process.remove(req_id)
             del self._reqs_to_send[req_id]
             done_sending.add(req_id)

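The removal added here keeps _reqs_to_process consistent with _reqs_to_send when a request times out. A rough sketch of the surrounding expiration sweep, assuming perf_counter-style timestamps and a hypothetical free-standing helper (the real code lives in the worker and also frees the KV blocks and emits the warning quoted above):

import time

def sweep_expired(reqs_to_send: dict[str, float],
                  reqs_to_process: set[str]) -> set[str]:
    """Hypothetical helper: drop requests whose expiration time has passed."""
    done_sending: set[str] = set()
    now = time.perf_counter()
    for req_id, expires_at in list(reqs_to_send.items()):
        if now >= expires_at:
            # Keep both tracking structures in sync, as the hunk above does.
            reqs_to_process.discard(req_id)
            del reqs_to_send[req_id]
            done_sending.add(req_id)
    return done_sending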
@@ -1097,7 +1109,8 @@ class NixlConnectorWorker:
         for notifs in self.nixl_wrapper.get_new_notifs().values():
             for notif in notifs:
                 req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if req_id not in self._reqs_to_send:
+                if (req_id not in self._reqs_to_send
+                        and req_id not in self._reqs_to_process):
                     logger.error(
                         "Potentially invalid KV blocks for "
                         "unrecognized request %s were retrieved by "
@@ -1110,7 +1123,8 @@ class NixlConnectorWorker:
                         tp_ratio):
                     notified_req_ids.add(req_id)
                     del self.consumer_notification_counts_by_req[req_id]
-                    del self._reqs_to_send[req_id]
+                    self._reqs_to_process.remove(req_id)
+                    self._reqs_to_send.pop(req_id, None)
         return notified_req_ids

     def _pop_done_transfers(
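Two details in this hunk are worth calling out: the request is also dropped from _reqs_to_process, and the hard `del self._reqs_to_send[req_id]` becomes `pop(req_id, None)`. Under async scheduling a decode notification can arrive before the producer side has armed an expiration for that request, so the entry may legitimately be missing. A small illustration with hypothetical request ids:

# Hypothetical data, not vLLM state.
reqs_to_send = {"req-a": 1234.5}

# A plain del raises if the decode-side notification for "req-b" arrives
# before an expiration was ever recorded for it:
try:
    del reqs_to_send["req-b"]
except KeyError:
    pass

# pop() with a default is simply a no-op in that case:
reqs_to_send.pop("req-b", None)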
@@ -1171,8 +1185,19 @@ class NixlConnectorWorker:
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())

+        # Keep around the requests that have been part of a batch. This is
+        # needed because async scheduling pushes the misalignment between the
+        # moment in which requests expiration is set (P side) and the moment in
+        # which blocks are read from D. As P can now more easily lag behind D
+        # while processing the next batch, we make sure to only set an
+        # expiration for requests that have not been read from D yet.
+        for req_id in metadata.reqs_in_batch:
+            self._reqs_to_process.add(req_id)
+
         # Add to requests that are waiting to be read and track expiration.
-        self._reqs_to_send.update(metadata.reqs_to_send)
+        for req_id, expiration_time in metadata.reqs_to_send.items():
+            if req_id in self._reqs_to_process:
+                self._reqs_to_send[req_id] = expiration_time

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
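This last hunk carries the core of the fix named in the commit title: with the async scheduler, the producer (P) can arm an expiration after the decoder (D) has already read the blocks, so expirations are now only armed for requests still present in _reqs_to_process. A condensed sketch of that gating, with hypothetical request ids and timestamps (not the real NixlConnectorWorker fields):

# Hypothetical, simplified state.
reqs_to_process = {"req-1"}                    # "req-2" was already read by D
incoming = {"req-1": 105.0, "req-2": 105.0}    # expirations sent by the scheduler

reqs_to_send: dict[str, float] = {}
for req_id, expires_at in incoming.items():
    # Arming a timeout for an already-served request would later trigger a
    # spurious "Releasing expired KV blocks" warning for it, so it is skipped.
    if req_id in reqs_to_process:
        reqs_to_send[req_id] = expires_at

assert reqs_to_send == {"req-1": 105.0}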