diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c205501e6c98..5af2b33f029c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_send: dict[ReqId, float] = {} + self.reqs_in_batch: set[ReqId] = set() def add_new_req( self, @@ -278,6 +279,7 @@ class NixlConnectorScheduler: self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} + self._reqs_in_batch: set[ReqId] = set() def get_num_new_matched_tokens( self, request: "Request", @@ -324,6 +326,9 @@ class NixlConnectorScheduler: if not params: return + + if params.get("do_remote_decode"): + self._reqs_in_batch.add(request.request_id) if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. @@ -373,6 +378,8 @@ class NixlConnectorScheduler: request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, + load_remote_cache=True, + save_to_host=False, ) for req_id, (req, block_ids) in self._reqs_need_save.items(): @@ -386,10 +393,12 @@ class NixlConnectorScheduler: ) meta.reqs_to_send = self._reqs_need_send + meta.reqs_in_batch = self._reqs_in_batch # Clear the list once workers start the transfers self._reqs_need_recv.clear() self._reqs_need_save.clear() + self._reqs_in_batch = set() self._reqs_need_send = {} return meta @@ -546,6 +555,8 @@ class NixlConnectorWorker: self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} + # Set of requests that have been part of a batch, regardless of status. + self._reqs_to_process: set[ReqId] = set() # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: Optional[threading.Thread] = None @@ -1082,6 +1093,7 @@ class NixlConnectorWorker: "Releasing expired KV blocks for request %s which were " "retrieved by %d decode worker(s) within %d seconds.", req_id, count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) + self._reqs_to_process.remove(req_id) del self._reqs_to_send[req_id] done_sending.add(req_id) @@ -1097,7 +1109,8 @@ class NixlConnectorWorker: for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) - if req_id not in self._reqs_to_send: + if (req_id not in self._reqs_to_send + and req_id not in self._reqs_to_process): logger.error( "Potentially invalid KV blocks for " "unrecognized request %s were retrieved by " @@ -1110,7 +1123,8 @@ class NixlConnectorWorker: tp_ratio): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] - del self._reqs_to_send[req_id] + self._reqs_to_process.remove(req_id) + self._reqs_to_send.pop(req_id, None) return notified_req_ids def _pop_done_transfers( @@ -1171,8 +1185,19 @@ class NixlConnectorWorker: while not self._ready_requests.empty(): self._read_blocks_for_req(*self._ready_requests.get_nowait()) + # Keep around the requests that have been part of a batch. This is + # needed because async scheduling pushes the misalignment between the + # moment in which requests expiration is set (P side) and the moment in + # which blocks are read from D. As P can now more easily lag behind D + # while processing the next batch, we make sure to only set an + # expiration for requests that have not been read from D yet. + for req_id in metadata.reqs_in_batch: + self._reqs_to_process.add(req_id) + # Add to requests that are waiting to be read and track expiration. - self._reqs_to_send.update(metadata.reqs_to_send) + for req_id, expiration_time in metadata.reqs_to_send.items(): + if req_id in self._reqs_to_process: + self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug(