diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c0039a668763c..0abfa489e1312 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -343,8 +343,8 @@ class NixlConnectorWorker: print(f"NUM_WORKERS: {num_workers=}") self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None, - num_workers=num_workers, - num_shared_workers=None) + num_workers=None, + num_shared_workers=num_workers) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) @@ -987,32 +987,28 @@ class NixlConnectorWorker: # Prepare transfer with Nixl. CHUNK_SIZE = 1000 handles = [] - futures = [] # NOTE: this is a hack to make make_prepped_xfer into threads so that # different workers are allocated for each chuck. Without this change, # nixl was allocating the same worker (0) for all the chunks and the # overall launch time was >300 ms. - with ThreadPoolExecutor() as executor: - for i in range(0, len(local_block_descs_ids), CHUNK_SIZE): - future = executor.submit( - self.nixl_wrapper.make_prepped_xfer, - "READ", - local_xfer_side_handle, - local_block_descs_ids[i:i + CHUNK_SIZE], - remote_xfer_side_handle, - remote_block_descs_ids[i:i + CHUNK_SIZE], - skip_desc_merge=True, - ) - futures.append(future) - - for future in futures: - handles.append(future.result()) + for i in range(0, len(local_block_descs_ids), CHUNK_SIZE): + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids[i:i + CHUNK_SIZE], + remote_xfer_side_handle, + remote_block_descs_ids[i:i + CHUNK_SIZE], + skip_desc_merge=True, + ) + handles.append(handle) # Begin async xfer. start = time.perf_counter() if USE_BATCHED: + print("BATCHED!") self.nixl_wrapper.transfer_batched(handles) else: + print("NON BATCHED!") for handle in handles: self.nixl_wrapper.transfer(handle) end = time.perf_counter()