mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 20:17:53 +08:00
fix format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
e0885e52d9
commit
795a305b1b
@ -640,11 +640,11 @@ class MoRIIOWrapper:
|
|||||||
status.Wait()
|
status.Wait()
|
||||||
if not status.Succeeded():
|
if not status.Succeeded():
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Transfer failed: {status.Message()}, Code: {status.Code()}"
|
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
|
||||||
)
|
)
|
||||||
raise TransferError("MoRIIO transfer failed!")
|
raise TransferError("MoRIIO transfer failed!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Transfer {status} failed: {e}")
|
logger.error("Transfer %s failed: %s", status, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def async_wait_reqid(self):
|
def async_wait_reqid(self):
|
||||||
@ -656,7 +656,7 @@ class MoRIIOWrapper:
|
|||||||
def _async_wait():
|
def _async_wait():
|
||||||
host = "*"
|
host = "*"
|
||||||
path = make_zmq_path("tcp", host, self.notify_port)
|
path = make_zmq_path("tcp", host, self.notify_port)
|
||||||
logger.info(f"Node starting to listen notify from path = {path}")
|
logger.info("Node starting to listen notify from path = %s", path)
|
||||||
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
while True:
|
while True:
|
||||||
@ -664,7 +664,7 @@ class MoRIIOWrapper:
|
|||||||
identity, msg = sock.recv_multipart()
|
identity, msg = sock.recv_multipart()
|
||||||
self._handle_message(msg)
|
self._handle_message(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing message: {e}")
|
logger.error("Error processing message: %s", e)
|
||||||
raise HandshakeError(f"Error processing message: {e}") from e
|
raise HandshakeError(f"Error processing message: {e}") from e
|
||||||
|
|
||||||
self.notify_thread = threading.Thread(
|
self.notify_thread = threading.Thread(
|
||||||
@ -696,7 +696,7 @@ class MoRIIOWrapper:
|
|||||||
self._handle_completion_message(msg_str)
|
self._handle_completion_message(msg_str)
|
||||||
handled = True
|
handled = True
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
logger.warning(f"Received non-UTF8 message: {msg_str}")
|
logger.warning("Received non-UTF8 message: %s", msg_str)
|
||||||
if not handled:
|
if not handled:
|
||||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||||
|
|
||||||
@ -740,11 +740,13 @@ class MoRIIOWrapper:
|
|||||||
try:
|
try:
|
||||||
for req_id in req_list:
|
for req_id in req_list:
|
||||||
if not isinstance(req_id, str):
|
if not isinstance(req_id, str):
|
||||||
logger.warning(f"Invalid req_id type: {type(req_id)}, expected str")
|
logger.warning(
|
||||||
|
"Invalid req_id type: %s, expected str", type(req_id)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
sock.send(req_id.encode("utf-8"))
|
sock.send(req_id.encode("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send notification to {path}: {e}")
|
logger.error("Failed to send notification to %s: %s", path, e)
|
||||||
self.paths.pop(path, None)
|
self.paths.pop(path, None)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -936,9 +938,8 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
if self.mode == MoRIIOMode.WRITE:
|
if self.mode == MoRIIOMode.WRITE and get_role() == ROLE.CONSUMER:
|
||||||
if get_role() == ROLE.CONSUMER:
|
self.connector_worker.moriio_wrapper.async_wait_reqid()
|
||||||
self.connector_worker.moriio_wrapper.async_wait_reqid()
|
|
||||||
|
|
||||||
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata)
|
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata)
|
||||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||||
@ -999,7 +1000,7 @@ class MoRIIOConnectorScheduler:
|
|||||||
"handshake_port"
|
"handshake_port"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
logger.info(f"Initializing MoRIIO Scheduler {engine_id = }")
|
logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id)
|
||||||
|
|
||||||
self.side_notify_port = (
|
self.side_notify_port = (
|
||||||
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
|
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
|
||||||
@ -1383,8 +1384,9 @@ class MoRIIOConnectorWorker:
|
|||||||
self._ping_thread.start()
|
self._ping_thread.start()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},"
|
"Initializing MoRIIO Engine, engine = %s, role = %s",
|
||||||
f"role = {'producer' if self.is_producer else 'consumer'}"
|
self.moriio_engine,
|
||||||
|
"producer" if self.is_producer else "consumer",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Agent.
|
# Agent.
|
||||||
@ -1550,7 +1552,7 @@ class MoRIIOConnectorWorker:
|
|||||||
|
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
index = 1
|
index = 1
|
||||||
|
should_break = True
|
||||||
with zmq_context.socket(zmq.DEALER) as sock:
|
with zmq_context.socket(zmq.DEALER) as sock:
|
||||||
sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}")
|
sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}")
|
||||||
|
|
||||||
@ -1574,30 +1576,33 @@ class MoRIIOConnectorWorker:
|
|||||||
|
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Connection refused: {self.local_ip}:"
|
"Connection refused: %s:%s -> %s:%s",
|
||||||
f"{self.local_ping_port} -> "
|
self.local_ip,
|
||||||
f"{self.proxy_ip}:{self.proxy_ping_port}"
|
self.local_ping_port,
|
||||||
|
self.proxy_ip,
|
||||||
|
self.proxy_ping_port,
|
||||||
)
|
)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.info(f"OS error when sending ping: {e}")
|
logger.info("OS error when sending ping: %s", e)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Unexpected error when sending ping: {e}")
|
logger.info("Unexpected error when sending ping: %s", e)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
|
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})"
|
"Max retries (%s) exceeded. Stopping ping loop.",
|
||||||
"exceeded. Stopping ping loop."
|
MoRIIOConstants.MAX_PING_RETRIES,
|
||||||
)
|
)
|
||||||
break
|
should_break = True
|
||||||
|
|
||||||
time.sleep(MoRIIOConstants.PING_INTERVAL)
|
time.sleep(MoRIIOConstants.PING_INTERVAL)
|
||||||
index += 1
|
index += 1
|
||||||
|
if should_break:
|
||||||
|
break
|
||||||
|
|
||||||
def handle_proxy_request(self):
|
def handle_proxy_request(self):
|
||||||
if self.is_producer:
|
if self.is_producer:
|
||||||
@ -1606,7 +1611,7 @@ class MoRIIOConnectorWorker:
|
|||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
socks = dict(self.poller.poll())
|
socks = dict(self.poller.poll())
|
||||||
logger.debug(f"handle_proxy_request: {socks = }")
|
logger.debug("handle_proxy_request: socks = %s", socks)
|
||||||
|
|
||||||
if self.metadata_socket not in socks:
|
if self.metadata_socket not in socks:
|
||||||
continue
|
continue
|
||||||
@ -1650,7 +1655,7 @@ class MoRIIOConnectorWorker:
|
|||||||
host = "*"
|
host = "*"
|
||||||
|
|
||||||
path = make_zmq_path("tcp", host, base_port)
|
path = make_zmq_path("tcp", host, base_port)
|
||||||
logger.debug(f" mori handshake starting listening on path: {path}")
|
logger.debug("mori handshake starting listening on path: %s", path)
|
||||||
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
ready_event.set()
|
ready_event.set()
|
||||||
@ -1695,11 +1700,11 @@ class MoRIIOConnectorWorker:
|
|||||||
|
|
||||||
port_offset = get_port_offset(remote_dp_rank, self.tp_rank)
|
port_offset = get_port_offset(remote_dp_rank, self.tp_rank)
|
||||||
path = make_zmq_path("tcp", host, port + port_offset)
|
path = make_zmq_path("tcp", host, port + port_offset)
|
||||||
logger.debug(f"handshake Querying metadata on path:{path}")
|
logger.debug("handshake Querying metadata on path: %s", path)
|
||||||
|
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
with zmq_ctx(zmq.DEALER, path) as sock:
|
with zmq_ctx(zmq.DEALER, path) as sock:
|
||||||
logger.info(f"prepare send msg INSTAZNCE: {path}")
|
logger.debug("prepare send msg INSTAZNCE: %s", path)
|
||||||
sock.send(MoRIIOConstants.GET_META_MSG)
|
sock.send(MoRIIOConstants.GET_META_MSG)
|
||||||
received_frame = sock.recv_multipart()
|
received_frame = sock.recv_multipart()
|
||||||
if len(received_frame) != 2 or received_frame[0] != b"":
|
if len(received_frame) != 2 or received_frame[0] != b"":
|
||||||
@ -1719,21 +1724,26 @@ class MoRIIOConnectorWorker:
|
|||||||
metadata.agent_metadata
|
metadata.agent_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"MoRIIO handshake: registered remote agent "
|
"MoRIIO handshake: registered"
|
||||||
f"{remote_agent_name=} for engine ID "
|
"remote agent %s for engine ID %s, path = %s",
|
||||||
f"{expected_engine_id=},f{path= }"
|
remote_agent_name,
|
||||||
|
expected_engine_id,
|
||||||
|
path,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self.local_kv_cache_metadata) > 0:
|
if len(self.local_kv_cache_metadata) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{len(self.local_kv_cache_metadata) = },"
|
"len(self.local_kv_cache_metadata) = %s,"
|
||||||
"maybe you didnt clear this buffer correctly"
|
"maybe you didnt clear this buffer correctly",
|
||||||
|
len(self.local_kv_cache_metadata),
|
||||||
)
|
)
|
||||||
self.local_kv_cache_metadata = []
|
self.local_kv_cache_metadata = []
|
||||||
if len(self.remote_kv_cache_metadata) > 0:
|
if len(self.remote_kv_cache_metadata) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f" {len(self.remote_kv_cache_metadata) = },"
|
"len(self.remote_kv_cache_metadata) = %s,"
|
||||||
"maybe you didnt clear this buffer correctly"
|
"maybe you didnt clear this buffer correctly",
|
||||||
|
len(self.remote_kv_cache_metadata),
|
||||||
)
|
)
|
||||||
self.remote_kv_cache_metadata = []
|
self.remote_kv_cache_metadata = []
|
||||||
|
|
||||||
@ -1995,7 +2005,6 @@ class MoRIIOConnectorWorker:
|
|||||||
# Initiate handshake with remote engine to exchange metadata.
|
# Initiate handshake with remote engine to exchange metadata.
|
||||||
with self._handshake_lock:
|
with self._handshake_lock:
|
||||||
if remote_engine_id not in self._remote_agents:
|
if remote_engine_id not in self._remote_agents:
|
||||||
logger.info(f"*****background moriio {remote_engine_id = }")
|
|
||||||
self._background_moriio_handshake(
|
self._background_moriio_handshake(
|
||||||
req_id, remote_engine_id, meta
|
req_id, remote_engine_id, meta
|
||||||
)
|
)
|
||||||
@ -2106,9 +2115,7 @@ class MoRIIOConnectorWorker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _is_last_layer(self, layer_name):
|
def _is_last_layer(self, layer_name):
|
||||||
if layer_name == list(self.kv_caches.keys())[-1]:
|
return layer_name == list(self.kv_caches.keys())[-1]
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def merge_contiguous_blocks(
|
def merge_contiguous_blocks(
|
||||||
self,
|
self,
|
||||||
@ -2256,7 +2263,7 @@ class MoRIIOConnectorWorker:
|
|||||||
first_layer, local_block_ids, remote_block_ids
|
first_layer, local_block_ids, remote_block_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
|
for layer_name in self.layer_name_to_local_kv_cache_metadata:
|
||||||
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
||||||
layer_name
|
layer_name
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user