From b60ee86585d34ef97c66dd6864b036d963f22fc6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 07:31:29 +0000 Subject: [PATCH] format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 109 ++++++++---------- 1 file changed, 47 insertions(+), 62 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 0b37f412fbfe4..cc9cc28a9980b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -1,19 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import logging import math import os import pickle import queue import threading import time -import msgpack -import msgspec -import numpy as np -import torch -import zmq -import logging - from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor @@ -21,6 +15,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from weakref import ref as weakref_ref +import msgpack +import msgspec +import numpy as np +import torch +import zmq + from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend @@ -68,10 +68,10 @@ class MoRIIOConstants: OVER = b"OVER" COMPLETION_PREFIX = "cmpl" - PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 + try: import mori from mori.io import ( @@ -202,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode: def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: - return ((dp_rank) * tp_size + tp_rank) + return (dp_rank) * tp_size + tp_rank @dataclass @@ -316,7 +316,6 @@ class MoRIIOWriter: def _write_worker_loop(self) -> None: """Main loop for the write worker thread.""" - logger.info("Write worker loop started") while True: # Process deferred tasks first @@ -408,7 +407,9 @@ class MoRIIOWriter: task.event.synchronize() # Update engine ID with DP rank - task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank) + task.dst_engine_id = self.worker.get_engine_name_with_dp( + task.dst_engine_id, request_info.decode_dp_rank + ) # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) @@ -544,7 +545,7 @@ class MoRIIOWrapper: self.done_write_cache_req_ids = [] self.notify_thread = None self.sock = None - self.sessions: list["IOEngine.Session"] = [] + self.sessions: list[IOEngine.Session] = [] self.paths = {} def set_moriio_engine(self, moriio_engine): @@ -968,7 +969,7 @@ class MoRIIOConnectorScheduler: "handshake_port" ] ) - logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }") + logger.info(f"Initializing MoRIIO Scheduler {engine_id = }") self.side_notify_port = ( self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] @@ -1149,7 +1150,7 @@ class MoRIIOConnectorScheduler: self._reqs_need_pending_save[req_id] = (req, updated_blocks) if ( len(self._reqs_need_pending_save[req_id][1]) - * self.block_size + * self.block_size >= req.num_prompt_tokens ): meta.add_new_req( @@ -1171,7 +1172,7 @@ class MoRIIOConnectorScheduler: for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - if req.num_prompt_tokens > len(block_ids)*self.block_size: + if req.num_prompt_tokens > len(block_ids) * self.block_size: # not last chunk prefill self._reqs_need_pending_save[req_id] = (req, block_ids) continue @@ -1265,7 +1266,7 @@ class MoRIIOConnectorWorker: self.mode = get_moriio_mode() logger.info("Initializing MoRIIO worker %s", engine_id) - + logging.getLogger("aiter").disabled = True # Config. @@ -1283,11 +1284,6 @@ class MoRIIOConnectorWorker: self.tp_rank = self.moriio_config.tp_rank self.dp_rank = self.moriio_config.dp_rank - logger.info( - f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }" - f",{self.is_producer= }" - ) - self.local_ip = self.moriio_config.local_ip self.local_kv_port = self.moriio_config.local_kv_port self.proxy_ip = self.moriio_config.proxy_ip @@ -1340,8 +1336,8 @@ class MoRIIOConnectorWorker: ), ) - logger.info( - "build IOEngine %s:%s", + logger.debug( + "build MORI IOEngine %s:%s", self.moriio_config.local_ip, self.moriio_config.local_kv_port, ) @@ -1390,8 +1386,6 @@ class MoRIIOConnectorWorker: self.moriio_config.handshake_port + get_port_offset(self.dp_rank, self.tp_rank) ) - logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }") - logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}") self.engine_id: EngineId = engine_id self.world_size = get_tensor_model_parallel_world_size() @@ -1522,7 +1516,6 @@ class MoRIIOConnectorWorker: return self.built_write_session[remote_engine_id] def _ping(self, zmq_context): - http_request_address = f"http://{self.request_address}/v1/completions" role = "P" if self.is_producer else "D" @@ -1586,15 +1579,18 @@ class MoRIIOConnectorWorker: if self.metadata_socket not in socks: continue - + def close(self): - if hasattr(self, '_handshake_initiation_executor'): + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) - - if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t: + + if ( + hasattr(self, "_moriio_handshake_listener_t") + and self._moriio_handshake_listener_t + ): self._moriio_handshake_listener_t.join(timeout=0) - - if hasattr(self, 'zmq_context') and self.zmq_context: + + if hasattr(self, "zmq_context") and self.zmq_context: self.zmq_context.destroy(linger=0) self.zmq_context = None @@ -1621,12 +1617,9 @@ class MoRIIOConnectorWorker: # Listen for new requests for metadata. host = "*" - logger.info( - f"======> mori handshake starting listening on baseport: {base_port}" - ) path = make_zmq_path("tcp", host, base_port) - logger.info(f"======> mori handshake starting listening on path: {path}") + logger.debug(f" mori handshake starting listening on path: {path}") with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -1636,20 +1629,20 @@ class MoRIIOConnectorWorker: msg != MoRIIOConstants.GET_META_MSG and msg != MoRIIOConstants.POP_DONE_RECV ): - logger.error("Connection listener got unexpected message %s", msg) + logger.error("Connection listener got unexpected message") raise HandshakeError("handshake failed, unexpected msg type") elif msg == MoRIIOConstants.GET_META_MSG: sock.send_multipart( (identity, b"", encoded_data) ) # send local mori io engine meta data - logger.info("MoRIIO handshake listener sent metadata to %s") + logger.debug("MoRIIO handshake listener sent metadata") # now we send tensor meta data for each block buf = pickle.dumps(layer_name_to_local_kv_cache_metadata) sock.send_multipart((identity, b"", buf)) elif msg == MoRIIOConstants.POP_DONE_RECV: _, req_id = sock.recv_multipart() - logger.info( - "MoRIIO handshake listener received done recv for req %s", + logger.debug( + "MoRIIO handshake listener received done recv for req", req_id.decode(), ) @@ -1671,10 +1664,7 @@ class MoRIIOConnectorWorker: port_offset = get_port_offset(remote_dp_rank, self.tp_rank) path = make_zmq_path("tcp", host, port + port_offset) - logger.info( - "handshake Querying metadata on path: %s at remote rank %s", - path, - ) + logger.debug(f"handshake Querying metadata on path:{path}") # Send query for the request. with zmq_ctx(zmq.DEALER, path) as sock: @@ -1694,7 +1684,7 @@ class MoRIIOConnectorWorker: ) self.moriio_wrapper.remote_engine_ip = host - remote_agent_name=self.moriio_wrapper.register_remote_engine( + remote_agent_name = self.moriio_wrapper.register_remote_engine( metadata.agent_metadata ) @@ -1748,7 +1738,7 @@ class MoRIIOConnectorWorker: def request_ready(_f: Future[Any], entry=(req_id, meta)): logger.info("MoRIIO handshake done for request %s", req_id) self._ready_requests.put(entry) - self.load_ready_flag [remote_engine_id] = True + self.load_ready_flag[remote_engine_id] = True self.write_ready_flags[remote_engine_id] = True fut_list = [] @@ -1999,10 +1989,10 @@ class MoRIIOConnectorWorker: break else: break - - - def get_engine_name_with_dp(self, engine_name,dp_rank): + + def get_engine_name_with_dp(self, engine_name, dp_rank): return f"{engine_name}_dp{dp_rank}" + def start_load_kv(self, metadata: MoRIIOConnectorMetadata): """ Start loading by triggering non-blocking moriio_xfer. @@ -2022,7 +2012,7 @@ class MoRIIOConnectorWorker: str(meta.remote_host) + ":" + str(meta.remote_handshake_port) ) meta.remote_engine_id = remote_engine_id - dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0) + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -2038,20 +2028,21 @@ class MoRIIOConnectorWorker: self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. - while True: + while True: if ( self._ready_requests.empty() and remote_engine_id not in self.load_ready_flag and wait_handshake_readd_req ): continue - elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag: + elif ( + not self._ready_requests.empty() + and remote_engine_id in self.load_ready_flag + ): self._read_blocks_for_req(*self._ready_requests.get_nowait()) break else: break - - self._reqs_to_send.update(metadata.reqs_to_send) @@ -2089,11 +2080,6 @@ class MoRIIOConnectorWorker: return True return False - def _is_first_layer(self, layer_name): - if layer_name == list(self.kv_caches.keys())[0]: - return True - return False - def merge_contiguous_blocks( self, offsets_local: list[int], @@ -2230,9 +2216,8 @@ class MoRIIOConnectorWorker: ) -> None: if self.mode == MoRIIOMode.WRITE: return - - - dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0) + + dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) sessions = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]