From 08cd2efbb69c5cfe03cafddb99ba8dae18520ba0 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 10:46:37 +0000 Subject: [PATCH] refine Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 39d06b6e0bced..22ba1655eabd6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -694,9 +694,9 @@ class MoRIIOWrapper: self._handle_completion_message(msg_str) handled = True except UnicodeDecodeError: - logger.warning(f"Received non-UTF8 message: {msg}") + logger.warning(f"Received non-UTF8 message: {msg_str}") if not handled: - raise MoRIIOError(f"Unhandled message format: {msg}") + raise MoRIIOError(f"Unhandled message format: {msg_str}") def _handle_structured_message(self, data: dict): req_id = data["req_id"] @@ -784,7 +784,7 @@ class ReqMeta: remote_host: str remote_port: int remote_handshake_port: int - remote_notify_port: int + remote_notify_port: int | None remote_engine_id: str tp_size: int remote_dp_size: int @@ -1011,7 +1011,7 @@ class MoRIIOConnectorScheduler: self._reqs_need_send: dict[ReqId, float] = {} self.sock = None self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" - self.paths = {} + self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens( self, @@ -1043,7 +1043,7 @@ class MoRIIOConnectorScheduler: return len(request.prompt_token_ids) - 1 - num_computed_tokens, False def send_notify_block( - self, req_id: str, block_notify_list: list[int] = None, host=None, port=None + self, req_id: str, block_notify_list: list[int] , host=None, port=None ): path = make_zmq_path("tcp", host, port) if path not in self.paths: @@ -1374,25 +1374,24 @@ class MoRIIOConnectorWorker: self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.notify_port = self.moriio_config.notify_port - self.local_kv_cache_metadata = [] - self.local_kv_cache_size = [] - self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() + self.local_kv_cache_metadata: list[bytes] = [] + self.local_kv_cache_size: list[int] = [] + self.layer_name_to_local_kv_cache_metadata: dict[str, list[bytes]] = {} - self.remote_kv_cache_metadata = [] - self.remote_kv_cache_size = [] + self.remote_kv_cache_metadata: list[bytes] = [] + self.remote_kv_cache_size: list[int] = [] self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( dict() ) self.slot_size_bytes = 0 - self.load_ready_flag = {} - self.write_ready_flags = {} + self.load_ready_flag: dict[str, bool] = {} + self.write_ready_flags: dict[str, bool] = {} self.kv_cache_shape = None self.block_shape = None self.kv_element_size = 0 - self.done_sending_reqs = [] - self.done_send_threads = [] + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -1452,7 +1451,6 @@ class MoRIIOConnectorWorker: self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) - self.debug_cache = [] backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype,