diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index aff27c4555907..46d7c157ef581 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -36,6 +36,8 @@ if TYPE_CHECKING:
 
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
 GET_META_MSG = b"get_meta_msg"
+# Upper bound on block descriptors per prepped xfer; reads are chunked by it.
+NIXL_MAX_DESCS = 1000
 
 logger = init_logger(__name__)
 
@@ -371,8 +373,8 @@ class NixlConnectorWorker:
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
-        # [req_id -> list[handle]]
-        self._recving_transfers = defaultdict[str, list[Transfer]](list)
+        # [req_id -> (list[handle], agent_name, notif_id)]
+        self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -826,7 +828,8 @@ class NixlConnectorWorker:
         return notified_req_ids
 
     def _pop_done_transfers(
-            self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
+            self, transfers: dict[str, tuple[list[int], str,
+                                             str]]) -> set[str]:
         """
         Pop completed xfers by checking for DONE state.
         Args:
@@ -835,18 +838,34 @@ class NixlConnectorWorker:
             set of req_ids that have all done xfers
         """
         done_req_ids: set[str] = set()
-        for req_id, handles in list(transfers.items()):
-            for handle, xfer_stime in handles:
+        for req_id, (handles, agent_name, notif_id) in list(transfers.items()):
+            # Handles still in flight after this poll.
+            new_handles = []
+            for handle in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
-                    done_req_ids.add(req_id)
-                    del transfers[req_id]
                 elif xfer_state == "PROC":
-                    continue
+                    new_handles.append(handle)
                 else:
                     raise RuntimeError("Transfer failed with state %s",
                                        xfer_state)
+
+            if not new_handles:
+                # Every chunk finished: send the notification, then stop
+                # tracking the request.
+                start = time.perf_counter()
+                self.nixl_wrapper.send_notif(agent_name, notif_id)
+                del transfers[req_id]
+                done_req_ids.add(req_id)
+                end = time.perf_counter()
+                logger.info("========= SEND NOTIF TIME: %s =========",
+                            end - start)
+            else:
+                # Keep the (handles, agent_name, notif_id) field order used
+                # when the entry was created.
+                transfers[req_id] = (new_handles, agent_name, notif_id)
+
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
@@ -958,25 +977,33 @@ class NixlConnectorWorker:
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
-        handle = self.nixl_wrapper.make_prepped_xfer(
-            "READ",
-            local_xfer_side_handle,
-            local_block_descs_ids,
-            remote_xfer_side_handle,
-            remote_block_descs_ids,
-            notif_msg=notif_id,
-        )
+        # Split the read into chunks of at most NIXL_MAX_DESCS descriptors,
+        # one prepped xfer handle per chunk.
+        handles = []
+        for i in range(0, len(local_block_descs_ids), NIXL_MAX_DESCS):
+            handles.append(
+                self.nixl_wrapper.make_prepped_xfer(
+                    "READ",
+                    local_xfer_side_handle,
+                    local_block_descs_ids[i:i + NIXL_MAX_DESCS],
+                    remote_xfer_side_handle,
+                    remote_block_descs_ids[i:i + NIXL_MAX_DESCS],
+                    skip_desc_merge=True,
+                ))
 
         # Begin async xfer.
         start = time.perf_counter()
-        self.nixl_wrapper.transfer(handle)
+        self.nixl_wrapper.transfer_batched(handles)
         end = time.perf_counter()
         logger.info("======== LAUNCH TIME: %s ========", end - start)
 
-        # Use handle to check completion in future step().
-        # TODO (NickLucche) surface xfer elapsed time
-        self._recving_transfers[request_id].append(
-            (handle, time.perf_counter()))
+        # Keep track of ongoing transfers. The notification is sent from
+        # _pop_done_transfers once every chunk is done, so record the remote
+        # agent name and notif id alongside the handles.
+        remote_rank = self.tp_rank // tp_ratio
+        agent_name = self._remote_agents[dst_engine_id][remote_rank]
+        assert request_id not in self._recving_transfers
+        self._recving_transfers[request_id] = (handles, agent_name, notif_id)
 
     def _get_block_descs_ids(self,
                              engine_id: str,