Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 07:17:34 +00:00
parent 64694c3e76
commit 245b71a891

View File

@ -68,8 +68,6 @@ class MoRIIOConstants:
OVER = b"OVER"
COMPLETION_PREFIX = "cmpl"
# Default GPU count per node for standard configurations
RANK_PER_NODE = 8
PING_INTERVAL = 5
MAX_PING_RETRIES = 100000
@ -204,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode:
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    """Return the port offset for a (data-parallel, tensor-parallel) rank pair.

    The offset is the flattened rank index ``dp_rank * tp_size + tp_rank``,
    so every distinct ``(dp_rank, tp_rank)`` pair with ``tp_rank < tp_size``
    maps to a distinct offset.

    Args:
        dp_rank: Data-parallel rank.
        tp_rank: Tensor-parallel rank.
        tp_size: Tensor-parallel world size; defaults to 1.

    Returns:
        The non-wrapping flattened rank index.
    """
    # A single return: the earlier variant that wrapped the index by
    # MoRIIOConstants.RANK_PER_NODE made this statement unreachable and let
    # ranks beyond one node's worth collide on the same offset.
    return dp_rank * tp_size + tp_rank
@dataclass
@ -405,8 +403,8 @@ class MoRIIOWriter:
return
# Wait for CUDA event
#The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
#This event is used to synchronize the kv transfer and computation tasks.
# The current layer's attention computation must not overlap with the KV transfer task; overlapping them causes precision issues.
# This event synchronizes the KV transfer with the computation.
task.event.synchronize()
# Update engine ID with DP rank
@ -546,8 +544,7 @@ class MoRIIOWrapper:
self.done_write_cache_req_ids = []
self.notify_thread = None
self.sock = None
self.sessions = []
self.kv_caches = None
self.sessions: list["IOEngine.Session"] = []
self.paths = {}
def set_moriio_engine(self, moriio_engine):
@ -730,7 +727,7 @@ class MoRIIOWrapper:
path = make_zmq_path("tcp", remote_ip, str(remote_port))
if path not in self.paths:
ctx = zmq.Context()
ctx = zmq.Context.instance()
sock = make_zmq_socket(
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
)
@ -1033,7 +1030,7 @@ class MoRIIOConnectorScheduler:
):
path = make_zmq_path("tcp", host, port)
if path not in self.paths:
ctx = zmq.Context()
ctx = zmq.Context.instance()
sock = make_zmq_socket(
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
)
@ -1268,7 +1265,7 @@ class MoRIIOConnectorWorker:
self.mode = get_moriio_mode()
logger.info("Initializing MoRIIO worker %s", engine_id)
# for debug
logging.getLogger("aiter").disabled = True
# Config.
@ -1377,7 +1374,7 @@ class MoRIIOConnectorWorker:
)
self.slot_size_bytes = 0
self.load_ready_flag = False
self.load_ready_flag = {}
self.write_ready_flags = {}
self.kv_cache_shape = None
self.block_shape = None
@ -1569,9 +1566,9 @@ class MoRIIOConnectorWorker:
retry_count += 1
finally:
if retry_count >= MoRIIOConstants.MAX_RETRIES:
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
logger.error(
f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
)
break
@ -1590,12 +1587,19 @@ class MoRIIOConnectorWorker:
if self.metadata_socket not in socks:
continue
def close(self):
    """Release the worker's background resources without blocking.

    Each resource is looked up with ``getattr(..., None)`` so the method is
    safe to call on an instance whose initialisation stopped early (the
    attribute is missing or still ``None``) and safe to call repeatedly.

    Steps, in order:
      1. Shut down the handshake-initiation executor without waiting.
      2. Attempt a zero-timeout join on the handshake listener.
      3. Destroy the ZMQ context with ``linger=0`` and drop the reference.
    """
    # NOTE(review): presumably a concurrent.futures executor -- confirm.
    executor = getattr(self, "_handshake_initiation_executor", None)
    if executor is not None:
        # wait=False: do not block on handshake tasks still in flight.
        executor.shutdown(wait=False)

    # NOTE(review): presumably a threading.Thread -- confirm.
    listener = getattr(self, "_moriio_handshake_listener_t", None)
    if listener:
        # timeout=0 turns the join into a non-blocking attempt.
        listener.join(timeout=0)

    context = getattr(self, "zmq_context", None)
    if context:
        context.destroy(linger=0)
        # Clear the reference so a repeated close() skips the destroy.
        self.zmq_context = None
def __del__(self):
    """Cleanup background threads on destruction."""
    # All teardown is delegated to close(), which checks that each resource
    # attribute exists before touching it. Accessing the executor and the
    # listener thread directly here would raise AttributeError when the
    # instance was only partially constructed, and would repeat the same
    # shutdown/join that close() performs anyway.
    self.close()
@staticmethod
def _moriio_handshake_listener(
@ -1744,7 +1748,7 @@ class MoRIIOConnectorWorker:
def request_ready(_f: Future[Any], entry=(req_id, meta)):
logger.info("MoRIIO handshake done for request %s", req_id)
self._ready_requests.put(entry)
self.load_ready_flag = True
self.load_ready_flag [remote_engine_id] = True
self.write_ready_flags[remote_engine_id] = True
fut_list = []
@ -1826,15 +1830,6 @@ class MoRIIOConnectorWorker:
kv_caches_base_addr = []
caches_data = []
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# elegantly support MLA and any cases where the K and V tensors
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded MoRIIOAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
cache_list = (
[cache_or_caches]
@ -2043,18 +2038,20 @@ class MoRIIOConnectorWorker:
self._read_blocks_for_req(req_id, meta)
# Start transfers for requests whose handshakes have now finished.
while True: # TODO
while True:
if (
self._ready_requests.empty()
and not self.load_ready_flag
and remote_engine_id not in self.load_ready_flag
and wait_handshake_readd_req
):
continue
elif not self._ready_requests.empty() and self.load_ready_flag:
elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag:
self._read_blocks_for_req(*self._ready_requests.get_nowait())
break
else:
break
self._reqs_to_send.update(metadata.reqs_to_send)
@ -2233,8 +2230,8 @@ class MoRIIOConnectorWorker:
) -> None:
if self.mode == MoRIIOMode.WRITE:
return
# we only test TP<->TP in read mode
# assert self.dp_rank>0, "only test TP<->TP in read mode"
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
sessions = self._get_built_session(dp0_engine_id)
@ -2248,7 +2245,7 @@ class MoRIIOConnectorWorker:
layer_name
)
transfer_status = self.moriio_wrapper.read_remote_data(
offs[0], offs[1], offs[2], sessions[sess_idx]
offs[2], offs[0], offs[1], sessions[sess_idx]
)
self._recving_transfers[request_id].append(transfer_status)