diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f297d44ccf888..325c54eca07d6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -38,6 +38,9 @@ if TYPE_CHECKING: Transfer = tuple[int, float] # (xfer_handle, start_time) GET_META_MSG = b"get_meta_msg" NIXL_MAX_DESCS = 1000 +import os + +USE_BATCHED = os.getenv("USE_BATCHED", "1") == "1" logger = init_logger(__name__) @@ -717,8 +720,6 @@ class NixlConnectorWorker: # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: - print(f"{self.dst_num_blocks[engine_id]=}") - print(f"{nixl_agent_meta.num_blocks=}") assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks else: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1009,9 +1010,11 @@ class NixlConnectorWorker: # Begin async xfer. start = time.perf_counter() - for handle in handles: - self.nixl_wrapper.transfer(handle) - # self.nixl_wrapper.transfer_batched(handles) + if USE_BATCHED: + self.nixl_wrapper.transfer_batched(handles) + else: + for handle in handles: + self.nixl_wrapper.transfer(handle) end = time.perf_counter() logger.info("======== LAUNCH TIME: %s ========", end - start)