Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-23 10:06:19 +08:00)
[Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit da63274d9f
parent c216119d64
@@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
         self.reqs_to_send: dict[ReqId, float] = {}
+        self.reqs_in_batch: set[ReqId] = set()
 
     def add_new_req(
         self,
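This hunk gives the per-step metadata object a fourth container: the set of request ids that were part of the current batch, shipped from scheduler to worker alongside the existing transfer maps. A minimal standalone sketch (not vLLM's actual class; names mirror the diff) of the resulting shape:

# Illustrative sketch of the metadata shape after this change.
ReqId = str

class ConnectorMetadataSketch:
    def __init__(self) -> None:
        self.reqs_to_recv: dict[ReqId, object] = {}  # reqs to load from remote
        self.reqs_to_save: dict[ReqId, object] = {}  # reqs to stage to host
        self.reqs_to_send: dict[ReqId, float] = {}   # req id -> expiration time
        self.reqs_in_batch: set[ReqId] = set()       # new: batch membership

meta = ConnectorMetadataSketch()
meta.reqs_in_batch.add("req-0")
meta.reqs_to_send["req-0"] = 123.0  # e.g. some monotonic deadline
assert "req-0" in meta.reqs_in_batch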
@@ -278,6 +279,7 @@ class NixlConnectorScheduler:
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
+        self._reqs_in_batch: set[ReqId] = set()
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -324,6 +326,9 @@ class NixlConnectorScheduler:
         if not params:
             return
 
+        if params.get("do_remote_decode"):
+            self._reqs_in_batch.add(request.request_id)
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
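Note that batch membership is recorded for every remote-decode request, independent of the narrower host-buffer condition that follows it. A hedged sketch of the control flow above (function name and stand-in bodies are mine, the conditions mirror the diff):

from typing import Optional

def update_state_after_alloc_sketch(request_id: str,
                                    params: Optional[dict],
                                    use_host_buffer: bool,
                                    reqs_in_batch: set) -> None:
    if not params:
        return
    # Every remote-decode request is recorded as "seen in a batch" ...
    if params.get("do_remote_decode"):
        reqs_in_batch.add(request_id)
    # ... while host-buffer staging remains a separate, narrower case.
    if use_host_buffer and params.get("do_remote_decode"):
        pass  # stage prefilled blocks to host memory before transfer

batch: set = set()
update_state_after_alloc_sketch("req-0", {"do_remote_decode": True}, False, batch)
assert batch == {"req-0"}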
@@ -373,6 +378,8 @@ class NixlConnectorScheduler:
                 request_id=req_id,
                 local_block_ids=block_ids,
                 kv_transfer_params=req.kv_transfer_params,
+                load_remote_cache=True,
+                save_to_host=False,
             )
 
         for req_id, (req, block_ids) in self._reqs_need_save.items():
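As I read this hunk, the call site now passes load_remote_cache and save_to_host explicitly, so the load-from-remote path is distinguishable at a glance from the save-to-host path below it. An illustrative stand-in (signature and defaults are assumptions, not vLLM's actual add_new_req):

def add_new_req_sketch(request_id: str, local_block_ids: list,
                       kv_transfer_params: dict,
                       load_remote_cache: bool = True,
                       save_to_host: bool = False) -> dict:
    # Stand-in for NixlConnectorMetadata.add_new_req.
    return {"request_id": request_id,
            "load_remote_cache": load_remote_cache,
            "save_to_host": save_to_host}

req = add_new_req_sketch("req-0", [0, 1], {},
                         load_remote_cache=True, save_to_host=False)
assert req["load_remote_cache"] and not req["save_to_host"]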
@@ -386,10 +393,12 @@ class NixlConnectorScheduler:
             )
 
         meta.reqs_to_send = self._reqs_need_send
+        meta.reqs_in_batch = self._reqs_in_batch
 
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
+        self._reqs_in_batch = set()
         self._reqs_need_send = {}
 
         return meta
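Worth noting in this hunk: after handing the container to meta, the scheduler rebinds a fresh one (self._reqs_in_batch = set()) rather than calling .clear(), because meta holds a reference to the same object and clearing in place would also empty the metadata the workers are about to consume. A small aliasing sketch:

# Handoff-and-rebind vs. clear-in-place.
scheduler_reqs: set = {"req-0", "req-1"}

meta_view = scheduler_reqs      # handoff, as in meta.reqs_in_batch = ...
scheduler_reqs = set()          # rebind: meta_view keeps the old contents
assert meta_view == {"req-0", "req-1"}

meta_view2 = scheduler_reqs     # a second handoff
scheduler_reqs.clear()          # clearing in place wipes both views
assert meta_view2 == set()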
@@ -546,6 +555,8 @@ class NixlConnectorWorker:
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
+        # Set of requests that have been part of a batch, regardless of status.
+        self._reqs_to_process: set[ReqId] = set()
 
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
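The worker now keeps two containers side by side: expirations in _reqs_to_send and batch membership in _reqs_to_process. An illustrative lifecycle sketch (plain containers, not the worker class):

reqs_to_send: dict = {}        # req id -> expiration timestamp
reqs_to_process: set = set()   # every req seen in a batch

reqs_to_process.add("req-0")   # batch membership arrives first
reqs_to_send["req-0"] = 99.0   # expiration may arrive a step later
reqs_to_process.remove("req-0")  # on completion/expiry: drop from both
reqs_to_send.pop("req-0", None)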
@@ -1082,6 +1093,7 @@ class NixlConnectorWorker:
                     "Releasing expired KV blocks for request %s which were "
                     "retrieved by %d decode worker(s) within %d seconds.", req_id,
                     count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+                self._reqs_to_process.remove(req_id)
                 del self._reqs_to_send[req_id]
                 done_sending.add(req_id)
 
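The expiry path now drops the request from both containers, keeping them consistent. A hedged sketch of such a sweep (iterating over a snapshot so entries can be deleted mid-loop; function name and demo values are mine):

import time
from typing import Optional

def sweep_expired(reqs_to_send: dict,
                  reqs_to_process: set,
                  now: Optional[float] = None) -> set:
    done_sending: set = set()
    now = time.perf_counter() if now is None else now
    for req_id, expires in list(reqs_to_send.items()):
        if now >= expires:
            reqs_to_process.remove(req_id)
            del reqs_to_send[req_id]
            done_sending.add(req_id)
    return done_sending

to_send = {"req-0": 1.0, "req-1": 100.0}
to_proc = {"req-0", "req-1"}
assert sweep_expired(to_send, to_proc, now=50.0) == {"req-0"}
assert to_send == {"req-1": 100.0} and to_proc == {"req-1"}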
@@ -1097,7 +1109,8 @@ class NixlConnectorWorker:
         for notifs in self.nixl_wrapper.get_new_notifs().values():
             for notif in notifs:
                 req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if req_id not in self._reqs_to_send:
+                if (req_id not in self._reqs_to_send
+                        and req_id not in self._reqs_to_process):
                     logger.error(
                         "Potentially invalid KV blocks for "
                         "unrecognized request %s were retrieved by "
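The "unrecognized request" guard is widened: a decode-side notification is only treated as invalid if the request is absent from both containers, since with async scheduling the notification can legitimately arrive before the expiration entry exists. A minimal sketch of the widened predicate:

def is_unrecognized(req_id: str,
                    reqs_to_send: dict,
                    reqs_to_process: set) -> bool:
    return req_id not in reqs_to_send and req_id not in reqs_to_process

# A request seen in a batch but without an expiration yet is still valid.
assert not is_unrecognized("req-0", {}, {"req-0"})
assert is_unrecognized("req-X", {}, {"req-0"})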
@@ -1110,7 +1123,8 @@ class NixlConnectorWorker:
                         tp_ratio):
                     notified_req_ids.add(req_id)
                     del self.consumer_notification_counts_by_req[req_id]
-                    del self._reqs_to_send[req_id]
+                    self._reqs_to_process.remove(req_id)
+                    self._reqs_to_send.pop(req_id, None)
         return notified_req_ids
 
     def _pop_done_transfers(
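Here the hard del becomes a tolerant pop: once all decode workers have pulled the blocks, the request always leaves _reqs_to_process, but its expiration entry may not exist yet, because under async scheduling the scheduler only installs it a step later. A tiny sketch of why pop(..., None) is required:

reqs_to_send: dict = {}           # expiration not set yet
reqs_to_process: set = {"req-0"}  # but the req was in a batch

reqs_to_process.remove("req-0")   # always present: safe to remove
reqs_to_send.pop("req-0", None)   # may be absent: pop, not del
# del reqs_to_send["req-0"]       # would raise KeyError here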
@@ -1171,8 +1185,19 @@ class NixlConnectorWorker:
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
 
+        # Keep around the requests that have been part of a batch. This is
+        # needed because async scheduling pushes the misalignment between the
+        # moment in which requests expiration is set (P side) and the moment in
+        # which blocks are read from D. As P can now more easily lag behind D
+        # while processing the next batch, we make sure to only set an
+        # expiration for requests that have not been read from D yet.
+        for req_id in metadata.reqs_in_batch:
+            self._reqs_to_process.add(req_id)
+
         # Add to requests that are waiting to be read and track expiration.
-        self._reqs_to_send.update(metadata.reqs_to_send)
+        for req_id, expiration_time in metadata.reqs_to_send.items():
+            if req_id in self._reqs_to_process:
+                self._reqs_to_send[req_id] = expiration_time
 
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
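This hunk is the core of the fix: expirations are now only installed for requests still pending, instead of unconditionally via update(). An end-to-end sketch of the race it closes, as described by the comment above (step numbering and demo values are illustrative; "P" is the prefill worker consuming scheduler metadata, "D" the decode side pulling blocks):

reqs_to_process: set = set()
reqs_to_send: dict = {}

# Step 1 (P): the request is part of a batch; no expiration exists yet,
# because with async scheduling reqs_to_send lags by at least one step.
reqs_to_process.add("req-0")

# Step 2 (D): decode already pulled the blocks; notification handling
# removes the request from the pending set.
reqs_to_process.remove("req-0")
reqs_to_send.pop("req-0", None)

# Step 3 (P): the stale expiration arrives. The old code did
# reqs_to_send.update({"req-0": 42.0}) unconditionally, resurrecting the
# entry and later firing a bogus timeout. The new code gates on membership:
for req_id, expiration in {"req-0": 42.0}.items():
    if req_id in reqs_to_process:
        reqs_to_send[req_id] = expiration

assert "req-0" not in reqs_to_send  # stale expiration correctly dropped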