From 17546dc79f5089565c22b6431e7324a998cc1461 Mon Sep 17 00:00:00 2001 From: Pravein Govindan Kannan Date: Mon, 30 Jun 2025 14:40:18 +0530 Subject: [PATCH] Add threading for load-balancing to different workers --- .../kv_connector/v1/nixl_connector.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c1f1db7e567cf..13c7806dbb7ac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional import msgspec import torch import zmq +from concurrent.futures import ThreadPoolExecutor, as_completed from vllm import envs from vllm.attention.selector import backend_name_to_enum, get_attn_backend @@ -986,16 +987,23 @@ class NixlConnectorWorker: # Prepare transfer with Nixl. CHUNK_SIZE = 100 handles = [] - for i in range(0, len(local_block_descs_ids), CHUNK_SIZE): - handles.append( - self.nixl_wrapper.make_prepped_xfer( + futures = [] + with ThreadPoolExecutor() as executor: + for i in range(0, len(local_block_descs_ids), CHUNK_SIZE): + future = executor.submit( + self.nixl_wrapper.make_prepped_xfer, "READ", local_xfer_side_handle, local_block_descs_ids[i:i + CHUNK_SIZE], remote_xfer_side_handle, remote_block_descs_ids[i:i + CHUNK_SIZE], skip_desc_merge=True, - )) + ) + futures.append(future) + + for future in futures: + handles.append(future.result()) + # Begin async xfer. start = time.perf_counter()