From 4776e2ddcf57af4a50f36122102c46e4a116c89c Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 12:13:24 +0000 Subject: [PATCH] more Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 11552e5b460a3..7ac303331146f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -37,7 +37,12 @@ from vllm.distributed.parallel_state import ( ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port +from vllm.utils.network_utils import ( + get_ip, + get_open_port, + make_zmq_path, + make_zmq_socket, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -70,7 +75,8 @@ class MoRIIOConstants: PING_INTERVAL = 5 MAX_PING_RETRIES = 100000 DEFAULT_HANDSHAKE_PORT = "6301" - DEFAULT_NOTIFY_PORT="61005" + DEFAULT_NOTIFY_PORT = "61005" + try: from mori.io import ( @@ -80,7 +86,7 @@ try: IOEngineConfig, MemoryDesc, PollCqMode, - RdmaBackendConfig + RdmaBackendConfig, ) logger.info("MoRIIO is available") @@ -200,7 +206,6 @@ def get_moriio_mode() -> MoRIIOMode: return MoRIIOMode.READ else: return MoRIIOMode.WRITE - def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: @@ -224,7 +229,6 @@ class MoRIIOConfig: @classmethod def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": - # Port Configuration: # local_ping_port -> Outgoing heartbeat to proxy # proxy_ping_port -> Remote proxy's heartbeat ingress port @@ -233,8 +237,8 @@ class MoRIIOConfig: # notify_port -> For synchronizing stages between prefill and decode # handshake_port -> For initial handshake between mori engine - #TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports - + # TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports + kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -716,8 +720,7 @@ class MoRIIOWrapper: raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): - - assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages" + assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" req_id = data["req_id"] block_notify_list = data.get("block_notify_list", []) decode_dp_rank = data.get("decode_rank", 0) @@ -887,8 +890,9 @@ class MoRIIOConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( - "Initialized MoRIIO Connector,engine_id:%s,role: %s", - self.engine_id, role.value + "Initialized MoRIIO Connector,engine_id:%s,role: %s", + self.engine_id, + role.value, ) ############################################################ @@ -905,8 +909,6 @@ class MoRIIOConnector(KVConnectorBase_V1): if "notify_port" not in extra_config or not extra_config["notify_port"]: extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT - - def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -1194,7 +1196,7 @@ class MoRIIOConnectorScheduler: if new_block_ids is not None: block_ids = new_block_ids[0] - #TODO : hybrid attn, etc + # TODO : hybrid attn, etc req, existing_blocks = self._reqs_need_pending_save[req_id] updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) @@ -1356,26 +1358,17 @@ class MoRIIOConnectorWorker: self._ping_thread = None self._writer = MoRIIOWriter(self) - engine_suffix = ( - f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}" - f":tp {self.tp_rank}:dp {self.dp_rank}" + role = "producer" if self.is_producer else "consumer" + engine_suffix = f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:tp{self.tp_rank}:dp{self.dp_rank}" + self.moriio_engine = IOEngine( + f"{role}:{engine_suffix}", + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), ) - if not self.is_producer: - self.moriio_engine = IOEngine( - "consumer:" + engine_suffix, - IOEngineConfig( - self.moriio_config.local_ip, self.moriio_config.local_kv_port - ), - ) - else: - self.moriio_engine = IOEngine( - "producer:" + engine_suffix, - IOEngineConfig( - self.moriio_config.local_ip, self.moriio_config.local_kv_port - ), - ) logger.debug( - "build MORI IOEngine %s:%s", + "build MORI IOEngine %s (ip=%s port=%s)", + f"{role}:{engine_suffix}", self.moriio_config.local_ip, self.moriio_config.local_kv_port, ) @@ -1604,18 +1597,6 @@ class MoRIIOConnectorWorker: if should_break: break - # def handle_proxy_request(self): - # if self.is_producer: - # raise NotImplementedError( - # "prefill instance doesn't need to send kv cache in pull mode" - # ) - # while True: - # socks = dict(self.poller.poll()) - # logger.debug("handle_proxy_request: socks = %s", socks) - - # if self.metadata_socket not in socks: - # continue - def close(self): if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -2264,7 +2245,7 @@ class MoRIIOConnectorWorker: sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) - #TODO : apply multi-session batch-read when moriio support it + # TODO : apply multi-session batch-read when moriio support it transfer_status = self.moriio_wrapper.read_remote_data( offs[2], offs[0], offs[1], sessions[sess_idx] )