mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 23:02:17 +08:00
parent
f9617c75ad
commit
ee2a4b0889
@ -36,6 +36,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||||
GET_META_MSG = b"get_meta_msg"
|
GET_META_MSG = b"get_meta_msg"
|
||||||
|
NIXL_MAX_DESCS = 1000
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -371,8 +372,8 @@ class NixlConnectorWorker:
|
|||||||
self._registered_descs: list[Any] = []
|
self._registered_descs: list[Any] = []
|
||||||
|
|
||||||
# In progress transfers.
|
# In progress transfers.
|
||||||
# [req_id -> list[handle]]
|
# [req_id -> list[handles], agent_name, notif_id]
|
||||||
self._recving_transfers = defaultdict[str, list[Transfer]](list)
|
self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
|
||||||
|
|
||||||
# Complete transfer tracker. Used by the rank 0 to track finished
|
# Complete transfer tracker. Used by the rank 0 to track finished
|
||||||
# transactions on ranks 1 to N-1.
|
# transactions on ranks 1 to N-1.
|
||||||
@ -826,7 +827,8 @@ class NixlConnectorWorker:
|
|||||||
return notified_req_ids
|
return notified_req_ids
|
||||||
|
|
||||||
def _pop_done_transfers(
|
def _pop_done_transfers(
|
||||||
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
|
self, transfers: dict[str, tuple[list[int], str,
|
||||||
|
str]]) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Pop completed xfers by checking for DONE state.
|
Pop completed xfers by checking for DONE state.
|
||||||
Args:
|
Args:
|
||||||
@ -835,18 +837,29 @@ class NixlConnectorWorker:
|
|||||||
set of req_ids that have all done xfers
|
set of req_ids that have all done xfers
|
||||||
"""
|
"""
|
||||||
done_req_ids: set[str] = set()
|
done_req_ids: set[str] = set()
|
||||||
for req_id, handles in list(transfers.items()):
|
for req_id, (handles, agent_name, notif_id) in list(transfers.items()):
|
||||||
for handle, xfer_stime in handles:
|
new_handles = []
|
||||||
|
for handle in handles:
|
||||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||||
if xfer_state == "DONE":
|
if xfer_state == "DONE":
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
done_req_ids.add(req_id)
|
|
||||||
del transfers[req_id]
|
|
||||||
elif xfer_state == "PROC":
|
elif xfer_state == "PROC":
|
||||||
continue
|
new_handles.append(handle)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Transfer failed with state %s",
|
raise RuntimeError("Transfer failed with state %s",
|
||||||
xfer_state)
|
xfer_state)
|
||||||
|
|
||||||
|
# Done.
|
||||||
|
if len(new_handles) == 0:
|
||||||
|
start = time.perf_counter()
|
||||||
|
self.nixl_wrapper.send_notif(agent_name, notif_id)
|
||||||
|
del transfers[req_id]
|
||||||
|
done_req_ids.add(req_id)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print(f"========= SEND NOTIF TIME: {end - start} =========")
|
||||||
|
else:
|
||||||
|
transfers[req_id] = (new_handles, notif_id, agent_name)
|
||||||
|
|
||||||
return done_req_ids
|
return done_req_ids
|
||||||
|
|
||||||
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||||
@ -958,25 +971,32 @@ class NixlConnectorWorker:
|
|||||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||||
|
|
||||||
# Prepare transfer with Nixl.
|
# Prepare transfer with Nixl.
|
||||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
CHUNK_SIZE = 1000
|
||||||
"READ",
|
handles = []
|
||||||
local_xfer_side_handle,
|
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
|
||||||
local_block_descs_ids,
|
handles.append(
|
||||||
remote_xfer_side_handle,
|
self.nixl_wrapper.make_prepped_xfer(
|
||||||
remote_block_descs_ids,
|
"READ",
|
||||||
notif_msg=notif_id,
|
local_xfer_side_handle,
|
||||||
)
|
local_block_descs_ids[i:i + CHUNK_SIZE],
|
||||||
|
remote_xfer_side_handle,
|
||||||
|
remote_block_descs_ids[i:i + CHUNK_SIZE],
|
||||||
|
skip_desc_merge=True,
|
||||||
|
))
|
||||||
|
|
||||||
# Begin async xfer.
|
# Begin async xfer.
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
self.nixl_wrapper.transfer(handle)
|
# for handle in handles:
|
||||||
|
# self.nixl_wrapper.transfer(handle)
|
||||||
|
self.nixl_wrapper.transfer_batched(handles)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
logger.info("======== LAUNCH TIME: %s ========", end - start)
|
logger.info("======== LAUNCH TIME: %s ========", end - start)
|
||||||
|
|
||||||
# Use handle to check completion in future step().
|
# Keep track of ongoing transfers.
|
||||||
# TODO (NickLucche) surface xfer elapsed time
|
remote_rank = self.tp_rank // tp_ratio
|
||||||
self._recving_transfers[request_id].append(
|
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||||
(handle, time.perf_counter()))
|
assert request_id not in self._recving_transfers
|
||||||
|
self._recving_transfers[request_id] = (handles, agent_name, notif_id)
|
||||||
|
|
||||||
def _get_block_descs_ids(self,
|
def _get_block_descs_ids(self,
|
||||||
engine_id: str,
|
engine_id: str,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user