fix format

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-21 04:49:34 +00:00
parent e0885e52d9
commit 795a305b1b

View File

@ -640,11 +640,11 @@ class MoRIIOWrapper:
status.Wait()
if not status.Succeeded():
logger.error(
f"Transfer failed: {status.Message()}, Code: {status.Code()}"
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
)
raise TransferError("MoRIIO transfer failed!")
except Exception as e:
logger.error(f"Transfer {status} failed: {e}")
logger.error("Transfer %s failed: %s", status, e)
raise
def async_wait_reqid(self):
@ -656,7 +656,7 @@ class MoRIIOWrapper:
def _async_wait():
host = "*"
path = make_zmq_path("tcp", host, self.notify_port)
logger.info(f"Node starting to listen notify from path = {path}")
logger.info("Node starting to listen notify from path = %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
while True:
@ -664,7 +664,7 @@ class MoRIIOWrapper:
identity, msg = sock.recv_multipart()
self._handle_message(msg)
except Exception as e:
logger.error(f"Error processing message: {e}")
logger.error("Error processing message: %s", e)
raise HandshakeError(f"Error processing message: {e}") from e
self.notify_thread = threading.Thread(
@ -696,7 +696,7 @@ class MoRIIOWrapper:
self._handle_completion_message(msg_str)
handled = True
except UnicodeDecodeError:
logger.warning(f"Received non-UTF8 message: {msg_str}")
logger.warning("Received non-UTF8 message: %s", msg_str)
if not handled:
raise MoRIIOError(f"Unhandled message format: {msg_str}")
@ -740,11 +740,13 @@ class MoRIIOWrapper:
try:
for req_id in req_list:
if not isinstance(req_id, str):
logger.warning(f"Invalid req_id type: {type(req_id)}, expected str")
logger.warning(
"Invalid req_id type: %s, expected str", type(req_id)
)
continue
sock.send(req_id.encode("utf-8"))
except Exception as e:
logger.error(f"Failed to send notification to {path}: {e}")
logger.error("Failed to send notification to %s: %s", path, e)
self.paths.pop(path, None)
raise
@ -936,9 +938,8 @@ class MoRIIOConnector(KVConnectorBase_V1):
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
if self.mode == MoRIIOMode.WRITE:
if get_role() == ROLE.CONSUMER:
self.connector_worker.moriio_wrapper.async_wait_reqid()
if self.mode == MoRIIOMode.WRITE and get_role() == ROLE.CONSUMER:
self.connector_worker.moriio_wrapper.async_wait_reqid()
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
@ -999,7 +1000,7 @@ class MoRIIOConnectorScheduler:
"handshake_port"
]
)
logger.info(f"Initializing MoRIIO Scheduler {engine_id = }")
logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id)
self.side_notify_port = (
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
@ -1383,8 +1384,9 @@ class MoRIIOConnectorWorker:
self._ping_thread.start()
logger.info(
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},"
f"role = {'producer' if self.is_producer else 'consumer'}"
"Initializing MoRIIO Engine, engine = %s, role = %s",
self.moriio_engine,
"producer" if self.is_producer else "consumer",
)
# Agent.
@ -1550,7 +1552,7 @@ class MoRIIOConnectorWorker:
retry_count = 0
index = 1
should_break = True
with zmq_context.socket(zmq.DEALER) as sock:
sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}")
@ -1574,30 +1576,33 @@ class MoRIIOConnectorWorker:
except ConnectionRefusedError:
logger.info(
f"Connection refused: {self.local_ip}:"
f"{self.local_ping_port} -> "
f"{self.proxy_ip}:{self.proxy_ping_port}"
"Connection refused: %s:%s -> %s:%s",
self.local_ip,
self.local_ping_port,
self.proxy_ip,
self.proxy_ping_port,
)
retry_count += 1
except OSError as e:
logger.info(f"OS error when sending ping: {e}")
logger.info("OS error when sending ping: %s", e)
retry_count += 1
except Exception as e:
logger.info(f"Unexpected error when sending ping: {e}")
logger.info("Unexpected error when sending ping: %s", e)
retry_count += 1
finally:
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
logger.error(
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})"
"exceeded. Stopping ping loop."
"Max retries (%s) exceeded. Stopping ping loop.",
MoRIIOConstants.MAX_PING_RETRIES,
)
break
should_break = True
time.sleep(MoRIIOConstants.PING_INTERVAL)
index += 1
if should_break:
break
def handle_proxy_request(self):
if self.is_producer:
@ -1606,7 +1611,7 @@ class MoRIIOConnectorWorker:
)
while True:
socks = dict(self.poller.poll())
logger.debug(f"handle_proxy_request: {socks = }")
logger.debug("handle_proxy_request: socks = %s", socks)
if self.metadata_socket not in socks:
continue
@ -1650,7 +1655,7 @@ class MoRIIOConnectorWorker:
host = "*"
path = make_zmq_path("tcp", host, base_port)
logger.debug(f" mori handshake starting listening on path: {path}")
logger.debug("mori handshake starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
@ -1695,11 +1700,11 @@ class MoRIIOConnectorWorker:
port_offset = get_port_offset(remote_dp_rank, self.tp_rank)
path = make_zmq_path("tcp", host, port + port_offset)
logger.debug(f"handshake Querying metadata on path:{path}")
logger.debug("handshake Querying metadata on path: %s", path)
# Send query for the request.
with zmq_ctx(zmq.DEALER, path) as sock:
logger.info(f"prepare send msg INSTAZNCE: {path}")
logger.debug("prepare send msg INSTAZNCE: %s", path)
sock.send(MoRIIOConstants.GET_META_MSG)
received_frame = sock.recv_multipart()
if len(received_frame) != 2 or received_frame[0] != b"":
@ -1719,21 +1724,26 @@ class MoRIIOConnectorWorker:
metadata.agent_metadata
)
logger.info(
f"MoRIIO handshake: registered remote agent "
f"{remote_agent_name=} for engine ID "
f"{expected_engine_id=},f{path= }"
logger.debug(
"MoRIIO handshake: registered"
"remote agent %s for engine ID %s, path = %s",
remote_agent_name,
expected_engine_id,
path,
)
if len(self.local_kv_cache_metadata) > 0:
logger.warning(
f"{len(self.local_kv_cache_metadata) = },"
"maybe you didnt clear this buffer correctly"
"len(self.local_kv_cache_metadata) = %s,"
"maybe you didnt clear this buffer correctly",
len(self.local_kv_cache_metadata),
)
self.local_kv_cache_metadata = []
if len(self.remote_kv_cache_metadata) > 0:
logger.warning(
f" {len(self.remote_kv_cache_metadata) = },"
"maybe you didnt clear this buffer correctly"
"len(self.remote_kv_cache_metadata) = %s,"
"maybe you didnt clear this buffer correctly",
len(self.remote_kv_cache_metadata),
)
self.remote_kv_cache_metadata = []
@ -1995,7 +2005,6 @@ class MoRIIOConnectorWorker:
# Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock:
if remote_engine_id not in self._remote_agents:
logger.info(f"*****background moriio {remote_engine_id = }")
self._background_moriio_handshake(
req_id, remote_engine_id, meta
)
@ -2106,9 +2115,7 @@ class MoRIIOConnectorWorker:
)
def _is_last_layer(self, layer_name):
if layer_name == list(self.kv_caches.keys())[-1]:
return True
return False
return layer_name == list(self.kv_caches.keys())[-1]
def merge_contiguous_blocks(
self,
@ -2256,7 +2263,7 @@ class MoRIIOConnectorWorker:
first_layer, local_block_ids, remote_block_ids
)
for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
for layer_name in self.layer_name_to_local_kv_cache_metadata:
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
layer_name
)