update lock

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-24 06:53:45 +00:00
parent 374cc25e0f
commit 77321502e7

View File

@ -520,7 +520,8 @@ class MoRIIOWriter:
task.request_id, task.remote_ip, remote_port
)
# mark request as done, then we can free the blocks
self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
with self.worker.moriio_wrapper.lock:
self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
task.request_id
]
@ -1559,7 +1560,6 @@ class MoRIIOConnectorWorker:
retry_count = 0
index = 1
should_break = True
with zmq_context.socket(zmq.DEALER) as sock:
sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}")
@ -1598,18 +1598,17 @@ class MoRIIOConnectorWorker:
except Exception as e:
logger.info("Unexpected error when sending ping: %s", e)
retry_count += 1
finally:
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
logger.error(
"Max retries (%s) exceeded. Stopping ping loop.",
MoRIIOConstants.MAX_PING_RETRIES,
)
should_break = True
raise RuntimeError(f"Ping failed after {retry_count} retries")
finally:
time.sleep(MoRIIOConstants.PING_INTERVAL)
index += 1
if should_break:
break
def close(self):
if hasattr(self, "_handshake_initiation_executor"):
@ -1951,19 +1950,23 @@ class MoRIIOConnectorWorker:
def _pop_done_transfers(self) -> set[str]:
    """Notify and forget every receiving transfer whose latest status succeeded.

    For each request in ``self._recving_transfers`` whose most recently
    appended status reports ``Succeeded()``, a notification is sent to the
    callback address recorded for that request, and the request is then
    removed from both ``self._recving_transfers`` and
    ``self._recving_transfers_callback_addr``.

    Returns:
        The ids of the requests that were notified and removed by this call.
    """
    done_req_ids: set[str] = set()
    # The same lock guards the code that appends to these dicts, so the scan
    # and the removal below see a consistent snapshot.
    # NOTE(review): send_notify runs while the lock is held; confirm it never
    # tries to take the same lock if the lock is not re-entrant.
    with self.moriio_wrapper.lock:
        for req_id, status_list in self._recving_transfers.items():
            # A request with no recorded status yet cannot be finished.
            if not status_list or not status_list[-1].Succeeded():
                continue
            done_req_ids.add(req_id)
            callback_addr = self._recving_transfers_callback_addr[req_id]
            self.moriio_wrapper.send_notify(
                req_id,
                callback_addr[0],
                callback_addr[1],
            )
        # Delete only after the scan: removing keys from a dict while
        # iterating over it raises RuntimeError.
        for req_id in done_req_ids:
            del self._recving_transfers[req_id]
            del self._recving_transfers_callback_addr[req_id]
    return done_req_ids
def save_kv_layer(
self,
@ -2258,12 +2261,13 @@ class MoRIIOConnectorWorker:
transfer_status = self.moriio_wrapper.read_remote_data(
offs[2], offs[0], offs[1], sessions[sess_idx]
)
with self.moriio_wrapper.lock:
self._recving_transfers[request_id].append(transfer_status)
self._recving_transfers_callback_addr[request_id] = (
remote_host,
str(remote_notify_port + self.tp_rank),
)
self._recving_transfers[request_id].append(transfer_status)
self._recving_transfers_callback_addr[request_id] = (
remote_host,
str(remote_notify_port + self.tp_rank),
)
@contextlib.contextmanager