From 128eca2ce3bdfd455a25a01e5a768c6ed5350c8b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 30 Jun 2025 19:48:33 +0000 Subject: [PATCH] update for use batched Signed-off-by: rshaw@neuralmagic.com --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f297d44ccf888..325c54eca07d6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -38,6 +38,9 @@ if TYPE_CHECKING: Transfer = tuple[int, float] # (xfer_handle, start_time) GET_META_MSG = b"get_meta_msg" NIXL_MAX_DESCS = 1000 +import os + +USE_BATCHED = os.getenv("USE_BATCHED", "1") == "1" logger = init_logger(__name__) @@ -717,8 +720,6 @@ class NixlConnectorWorker: # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: - print(f"{self.dst_num_blocks[engine_id]=}") - print(f"{nixl_agent_meta.num_blocks=}") assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks else: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1009,9 +1010,11 @@ class NixlConnectorWorker: # Begin async xfer. start = time.perf_counter() - for handle in handles: - self.nixl_wrapper.transfer(handle) - # self.nixl_wrapper.transfer_batched(handles) + if USE_BATCHED: + self.nixl_wrapper.transfer_batched(handles) + else: + for handle in handles: + self.nixl_wrapper.transfer(handle) end = time.perf_counter() logger.info("======== LAUNCH TIME: %s ========", end - start)