Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-06-19 13:10:52 +00:00
parent dac8cc49f4
commit 489e5ba5ce

View File

@ -921,6 +921,15 @@ class NixlConnectorWorker:
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
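# Sort by local block id and apply the same permutation to the remote
# block ids, so each local/remote pair stays aligned while nixl can merge
# contiguous blocks.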
start = time.perf_counter()
sorted_idx = sorted(range(len(local_block_ids)),
key=local_block_ids.__getitem__)
local_block_ids = [local_block_ids[i] for i in sorted_idx]
remote_block_ids = [remote_block_ids[i] for i in sorted_idx]
end = time.perf_counter()
print(f"REORDER took: {end - start}")
# Get descs ids.
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
@ -965,10 +974,17 @@ class NixlConnectorWorker:
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
# skip_desc_merge=True,
)
# Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer(handle)
end = time.perf_counter()
print(f"self.nixl_wrapper.transfer() TIME: {end-start}")
if end - start > 0.2:
print(f"{local_block_ids=}")
print(f"{remote_block_ids=}")
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time