diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0c5986bfafaa0..5a49b8aa1244b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -41,13 +41,13 @@ Transfer = tuple[int, float] # (xfer_handle, start_time) EngineId = str ReqId = str GET_META_MSG = b"get_meta_msg" +NIXL_NUM_WORKERS = 32 logger = init_logger(__name__) # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: - from nixl._api import nixl_agent as NixlWrapper - logger.info("NIXL is available") + from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config except ImportError: logger.warning("NIXL is not available") NixlWrapper = None @@ -361,7 +361,12 @@ class NixlConnectorWorker: self.block_size = vllm_config.cache_config.block_size # Agent. - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + import os + if os.getenv("VLLM_USE_NIXL_WORKERS", "0") == "1": + config = nixl_agent_config(num_threads=NIXL_NUM_WORKERS) + else: + config = None + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -1019,10 +1024,14 @@ class NixlConnectorWorker: remote_xfer_side_handle, remote_block_descs_ids, notif_msg=notif_id, + skip_desc_merge=True, ) # Begin async xfer. + start = time.perf_counter() self.nixl_wrapper.transfer(handle) + end = time.perf_counter() + logger.info(f"TIME: {end - start}") # Use handle to check completion in future step(). # TODO (NickLucche) surface xfer elapsed time