Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-06-30 01:11:22 +00:00
parent f9617c75ad
commit ee2a4b0889

View File

@ -36,6 +36,7 @@ if TYPE_CHECKING:
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
NIXL_MAX_DESCS = 1000
logger = init_logger(__name__)
@ -371,8 +372,8 @@ class NixlConnectorWorker:
self._registered_descs: list[Any] = []
# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers = defaultdict[str, list[Transfer]](list)
# [req_id -> list[handles], agent_name, notif_id]
self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
@ -826,7 +827,8 @@ class NixlConnectorWorker:
return notified_req_ids
def _pop_done_transfers(
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
self, transfers: dict[str, tuple[list[int], str,
str]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
@ -835,18 +837,29 @@ class NixlConnectorWorker:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
for handle, xfer_stime in handles:
for req_id, (handles, agent_name, notif_id) in list(transfers.items()):
new_handles = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle)
done_req_ids.add(req_id)
del transfers[req_id]
elif xfer_state == "PROC":
continue
new_handles.append(handle)
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
# Done.
if len(new_handles) == 0:
start = time.perf_counter()
self.nixl_wrapper.send_notif(agent_name, notif_id)
del transfers[req_id]
done_req_ids.add(req_id)
end = time.perf_counter()
print(f"========= SEND NOTIF TIME: {end - start} =========")
else:
transfers[req_id] = (new_handles, notif_id, agent_name)
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -958,25 +971,32 @@ class NixlConnectorWorker:
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
# Prepare transfer with Nixl.
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
)
CHUNK_SIZE = 1000
handles = []
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
handles.append(
self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids[i:i + CHUNK_SIZE],
remote_xfer_side_handle,
remote_block_descs_ids[i:i + CHUNK_SIZE],
skip_desc_merge=True,
))
# Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer(handle)
# for handle in handles:
# self.nixl_wrapper.transfer(handle)
self.nixl_wrapper.transfer_batched(handles)
end = time.perf_counter()
logger.info("======== LAUNCH TIME: %s ========", end - start)
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
# Keep track of ongoing transfers.
remote_rank = self.tp_rank // tp_ratio
agent_name = self._remote_agents[dst_engine_id][remote_rank]
assert request_id not in self._recving_transfers
self._recving_transfers[request_id] = (handles, agent_name, notif_id)
def _get_block_descs_ids(self,
engine_id: str,