Update to use batched transfers

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-06-30 19:48:33 +00:00
parent 6babd39366
commit 128eca2ce3

View File

@@ -38,6 +38,9 @@ if TYPE_CHECKING:
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
NIXL_MAX_DESCS = 1000
import os
USE_BATCHED = os.getenv("USE_BATCHED", "1") == "1"
logger = init_logger(__name__)
@@ -717,8 +720,6 @@ class NixlConnectorWorker:
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
print(f"{self.dst_num_blocks[engine_id]=}")
print(f"{nixl_agent_meta.num_blocks=}")
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
else:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
@@ -1009,9 +1010,11 @@ class NixlConnectorWorker:
# Begin async xfer.
start = time.perf_counter()
for handle in handles:
self.nixl_wrapper.transfer(handle)
# self.nixl_wrapper.transfer_batched(handles)
if USE_BATCHED:
self.nixl_wrapper.transfer_batched(handles)
else:
for handle in handles:
self.nixl_wrapper.transfer(handle)
end = time.perf_counter()
logger.info("======== LAUNCH TIME: %s ========", end - start)