diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 325530cab69aa..3d3c393eb3992 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -37,6 +37,9 @@ if TYPE_CHECKING:
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
 GET_META_MSG = b"get_meta_msg"
 
+import os
+LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"  # opt-in timing logs
+
 logger = init_logger(__name__)
 
 # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
@@ -764,8 +767,11 @@ class NixlConnectorWorker:
         to Rank 0 once their transaction is done + Rank 0 returns finished sets
         to Scheduler only once all ranks are done.
         """
+
+        start = time.perf_counter()
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
+
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
@@ -806,6 +812,10 @@ class NixlConnectorWorker:
                 if self._done_sending_count[req_id] == self.world_size:
                     del self._done_sending_count[req_id]
                     all_done_sending.add(req_id)
+
+            end = time.perf_counter()
+            if LOG_XFER_TIME:
+                logger.info("========== .get_finished(): %s ==========", end - start)
 
             return all_done_sending, all_done_recving
 
@@ -815,6 +825,10 @@ class NixlConnectorWorker:
             self.tp_group.send_object(finished_req_ids, dst=0)
 
             # Unused as only Rank 0 results are sent to scheduler.
+            end = time.perf_counter()
+            if LOG_XFER_TIME:
+                logger.info("========== .get_finished(): %s ==========", end - start)
+
             return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
@@ -845,8 +859,8 @@ class NixlConnectorWorker:
         Returns:
             set of req_ids that have all done xfers
         """
         done_req_ids: set[str] = set()
-        for req_id, (handles, agent_name, notif_id) in list(transfers.items()):
+        for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
             new_handles = []
             for handle in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
@@ -863,6 +877,9 @@ class NixlConnectorWorker:
                 self.nixl_wrapper.send_notif(agent_name, notif_id)
                 del transfers[req_id]
                 done_req_ids.add(req_id)
+                if LOG_XFER_TIME:
+                    logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
+
             else:
-                transfers[req_id] = (new_handles, agent_name, notif_id)
+                transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
 
@@ -998,13 +1015,14 @@ class NixlConnectorWorker:
         start = time.perf_counter()
         self.nixl_wrapper.transfer_batched(handles)
         end = time.perf_counter()
-        logger.info("========== TRANSFER BATCHED: %s ==========", end - start)
+        if LOG_XFER_TIME:
+            logger.info("========== .transfer_batched(): %s ==========", end - start)
 
         # Keep track of ongoing transfers.
         remote_rank = self.tp_rank // tp_ratio
         agent_name = self._remote_agents[dst_engine_id][remote_rank]
         assert request_id not in self._recving_transfers
-        self._recving_transfers[request_id] = (handles, agent_name, notif_id)
+        self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
 
     def _get_block_descs_ids(self, engine_id: str,
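
For context, the pattern the patch applies is an import-time environment flag gating `time.perf_counter()` measurements, with the start timestamp threaded through each in-flight transfer record so completion latency can be logged when the transfer is reaped. Below is a minimal standalone sketch of that pattern, not vLLM code: `start_transfer`, `pop_done_transfers`, and `recving_transfers` are hypothetical stand-ins for the connector's internals.

```python
import logging
import os
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Same pattern as the patch: the flag is read once at import time, so
# VLLM_LOG_XFER_TIME=1 must be set before the process starts.
LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"

# Hypothetical stand-in for _recving_transfers: req_id -> start timestamp.
# The real table also carries xfer handles, the remote agent name, and a
# notification id; only the timing field matters for this sketch.
recving_transfers: dict[str, float] = {}


def start_transfer(req_id: str) -> None:
    """Issue a (pretend) transfer and record when it was kicked off."""
    recving_transfers[req_id] = time.perf_counter()


def pop_done_transfers() -> set[str]:
    """Drain finished transfers, logging per-request transmission time."""
    done: set[str] = set()
    for req_id, start_time in list(recving_transfers.items()):
        # A real implementation would poll check_xfer_state() here;
        # this sketch treats every transfer as already complete.
        del recving_transfers[req_id]
        done.add(req_id)
        if LOG_XFER_TIME:
            logger.info("========== transmission time: %s ==========",
                        time.perf_counter() - start_time)
    return done


if __name__ == "__main__":
    start_transfer("req-0")
    time.sleep(0.01)  # stand-in for an in-flight NIXL transfer
    pop_done_transfers()
```

Run with `VLLM_LOG_XFER_TIME=1` to see the timing line; with the variable unset the logging is skipped entirely, matching the patch's default-off behavior. Note that because the flag is evaluated at import time, toggling the environment variable after startup has no effect.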