From 70ea1b2460d67767d83f6a7f194e231c932f1fcf Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 04:36:19 +0000 Subject: [PATCH] refine code Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 80 +++++++------------ 1 file changed, 28 insertions(+), 52 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 697ae3a786271..603efc0507bed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -7,6 +7,13 @@ import pickle import queue import threading import time +import msgpack +import msgspec +import numpy as np +import torch +import zmq +import logging + from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor @@ -14,12 +21,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from weakref import ref as weakref_ref -import msgpack -import msgspec -import numpy as np -import torch -import zmq - from vllm import envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend @@ -50,7 +51,6 @@ if TYPE_CHECKING: from dataclasses import field from enum import Enum from queue import Empty, Queue -import logging logger = init_logger(__name__) @@ -70,7 +70,9 @@ class MoRIIOConstants: # Default GPU count per node for standard configurations RANK_PER_NODE = 8 - + + PING_INTERVAL = 5 + MAX_PING_RETRIES = 100000 try: import mori @@ -80,7 +82,6 @@ try: IOEngine, IOEngineConfig, MemoryDesc, - StatusCode, ) logger.info("MoRIIO is available") @@ -404,10 +405,12 @@ class MoRIIOWriter: return # Wait for CUDA event + #The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. + #This event is used to synchronize the kv transfer and computation tasks. task.event.synchronize() # Update engine ID with DP rank - task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}" + task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank) # Get or create sessions sessions = self.worker._get_built_session(task.dst_engine_id) @@ -665,7 +668,6 @@ class MoRIIOWrapper: except Exception as e: logger.error(f"Error processing message: {e}") raise HandshakeError(f"Error processing message: {e}") from e - continue self.notify_thread = threading.Thread( target=_async_wait, daemon=True, name="moriio-notify-listener" @@ -685,7 +687,6 @@ class MoRIIOWrapper: data = msgpack.loads(msg) if isinstance(data, dict) and "req_id" in data: self._handle_structured_message(data) - handled = True return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): @@ -841,6 +842,7 @@ class MoRIIOConnector(KVConnectorBase_V1): role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None, ): + super().__init__(vllm_config, role) assert vllm_config.kv_transfer_config is not None # assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id = ( @@ -1149,10 +1151,8 @@ class MoRIIOConnectorScheduler: ) self._reqs_need_pending_save[req_id] = (req, updated_blocks) if ( - len( - self._reqs_need_pending_save[req_id][1] + len(self._reqs_need_pending_save[req_id][1]) * self.block_size - ) >= req.num_prompt_tokens ): meta.add_new_req( @@ -1363,13 +1363,9 @@ class MoRIIOConnectorWorker: ) # Agent. self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) - self.moriio_wrapper.set_moriio_engine(self.moriio_engine) - self.moriio_wrapper.set_backend_type(BackendType.RDMA) - self.moriio_wrapper.notify_port = self.moriio_config.notify_port - self.local_kv_cache_metadata = [] self.local_kv_cache_size = [] self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() @@ -1450,7 +1446,6 @@ class MoRIIOConnectorWorker: self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) - self._write_session_lock = threading.Lock() self.debug_cache = [] backend = get_attn_backend( self.model_config.get_head_size(), @@ -1467,15 +1462,6 @@ class MoRIIOConnectorWorker: # self._use_flashinfer = attn_backend == _Backend.FLASHINFER logger.debug("Detected attention backend %s", self.backend_name) - self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} - - ####write worker### - # self._write_task_q: Queue[WriteTask] = Queue() - # self._write_worker_started = False - # self._write_worker_lock = threading.Lock() - # self._deferred_tasks: list[WriteTask] = [] - # ####write worker### - def schedule_write_blocks( self, request_id: str, @@ -1539,8 +1525,6 @@ class MoRIIOConnectorWorker: return self.built_write_session[remote_engine_id] def _ping(self, zmq_context): - PING_INTERVAL = 5 - MAX_RETRIES = 100000 http_request_address = f"http://{self.request_address}/v1/completions" role = "P" if self.is_producer else "D" @@ -1585,13 +1569,13 @@ class MoRIIOConnectorWorker: retry_count += 1 finally: - if retry_count >= MAX_RETRIES: + if retry_count >= MoRIIOConstants.MAX_RETRIES: logger.error( - f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop." + f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop." ) break - time.sleep(PING_INTERVAL) + time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 def handle_proxy_request(self): @@ -1603,7 +1587,6 @@ class MoRIIOConnectorWorker: socks = dict(self.poller.poll()) logger.debug(f"handle_proxy_request: {socks = }") - # TODO: inkcherry , check here? if self.metadata_socket not in socks: continue @@ -1665,8 +1648,6 @@ class MoRIIOConnectorWorker: "MoRIIO handshake listener received done recv for req %s", req_id.decode(), ) - else: - pass def _moriio_handshake( self, @@ -1709,9 +1690,7 @@ class MoRIIOConnectorWorker: ) self.moriio_wrapper.remote_engine_ip = host - remote_agent_name = self.moriio_wrapper.register_remote_engine( - metadata.agent_metadata - ) + remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key logger.info( @@ -1873,8 +1852,6 @@ class MoRIIOConnectorWorker: if layer_name not in self.layer_name_to_local_kv_cache_metadata: self.layer_name_to_local_kv_cache_metadata[layer_name] = [] - # for cache in cache_list: - # moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache) moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache) self.layer_name_to_local_kv_cache_metadata[layer_name].append( moriio_mem_metadata @@ -1942,7 +1919,7 @@ class MoRIIOConnectorWorker: to track which workers are done. """ - done_sending, done_recving = set(), set() + done_sending = set() if self.is_producer: done_sending = self.moriio_wrapper.pop_finished_req_ids() @@ -1991,7 +1968,6 @@ class MoRIIOConnectorWorker: remote_engine_id = None for req_id, meta in metadata.reqs_to_save.items(): - remote_engine_id = meta.remote_engine_id # we only need to check if dp0 in rank remote_engine_id = ( str(meta.remote_host) + ":" + str(meta.remote_handshake_port) @@ -2030,13 +2006,15 @@ class MoRIIOConnectorWorker: break else: break - + + + def get_engine_name_with_dp(self, engine_name,dp_rank): + return f"{engine_name}_dp{dp_rank}" def start_load_kv(self, metadata: MoRIIOConnectorMetadata): """ Start loading by triggering non-blocking moriio_xfer. We check for these trnxs to complete in each step(). """ - # print("start load kv") if self.is_producer: self.moriio_wrapper.async_wait_reqid() return @@ -2051,7 +2029,7 @@ class MoRIIOConnectorWorker: str(meta.remote_host) + ":" + str(meta.remote_handshake_port) ) meta.remote_engine_id = remote_engine_id - dp0_remote_engine_id = f"{remote_engine_id}_dp0" + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0) if dp0_remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -2216,7 +2194,6 @@ class MoRIIOConnectorWorker: blknum, blksize, hs = self.kv_cache_shape hn = 1 block_stride = stride[0] - ktov_stride = None else: _, blknum, blksize, hn, hs = self.kv_cache_shape ktov_stride = stride[0] @@ -2260,21 +2237,20 @@ class MoRIIOConnectorWorker: return # we only test TP<->TP in read mode # assert self.dp_rank>0, "only test TP<->TP in read mode" - dst_engine_id += "_dp0" - sessions = self._get_built_session(dst_engine_id) + dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0) + sessions = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] offs = self._compute_block_transfer_offsets( first_layer, local_block_ids, remote_block_ids ) - a, b, c = offs[0], offs[1], offs[2] for layer_name in self.layer_name_to_local_kv_cache_metadata.keys(): sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( layer_name ) transfer_status = self.moriio_wrapper.read_remote_data( - c, a, b, sessions[sess_idx] + offs[0], offs[1], offs[2], sessions[sess_idx] ) self._recving_transfers[request_id].append(transfer_status)