diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
index c04d6055fd040..0b37f412fbfe4 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
@@ -68,8 +68,6 @@ class MoRIIOConstants:
     OVER = b"OVER"
     COMPLETION_PREFIX = "cmpl"
 
-    # Default GPU count per node for standard configurations
-    RANK_PER_NODE = 8
     PING_INTERVAL = 5
     MAX_PING_RETRIES = 100000
 
@@ -204,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode:
 
 
 def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
-    return ((dp_rank) * tp_size + tp_rank) % MoRIIOConstants.RANK_PER_NODE
+    return ((dp_rank) * tp_size + tp_rank)
 
 
 @dataclass
@@ -405,8 +403,8 @@ class MoRIIOWriter:
             return
 
         # Wait for CUDA event
-        #The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
-        #This event is used to synchronize the kv transfer and computation tasks.
+        # The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
+        # This event is used to synchronize the kv transfer and computation tasks.
         task.event.synchronize()
 
         # Update engine ID with DP rank
@@ -546,8 +544,7 @@ class MoRIIOWrapper:
         self.done_write_cache_req_ids = []
         self.notify_thread = None
         self.sock = None
-        self.sessions = []
-        self.kv_caches = None
+        self.sessions: list["IOEngine.Session"] = []
         self.paths = {}
 
     def set_moriio_engine(self, moriio_engine):
@@ -730,7 +727,7 @@ class MoRIIOWrapper:
 
         path = make_zmq_path("tcp", remote_ip, str(remote_port))
         if path not in self.paths:
-            ctx = zmq.Context()
+            ctx = zmq.Context.instance()
             sock = make_zmq_socket(
                 ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
             )
@@ -1033,7 +1030,7 @@ class MoRIIOConnectorScheduler:
     ):
         path = make_zmq_path("tcp", host, port)
         if path not in self.paths:
-            ctx = zmq.Context()
+            ctx = zmq.Context.instance()
             sock = make_zmq_socket(
                 ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
             )
@@ -1268,7 +1265,7 @@ class MoRIIOConnectorWorker:
         self.mode = get_moriio_mode()
         logger.info("Initializing MoRIIO worker %s", engine_id)
 
-        # for debug
+
        logging.getLogger("aiter").disabled = True
 
         # Config.
@@ -1377,7 +1374,7 @@ class MoRIIOConnectorWorker:
         )
         self.slot_size_bytes = 0
-        self.load_ready_flag = False
+        self.load_ready_flag = {}
         self.write_ready_flags = {}
         self.kv_cache_shape = None
         self.block_shape = None
 
@@ -1569,9 +1566,9 @@ class MoRIIOConnectorWorker:
 
                 retry_count += 1
             finally:
-                if retry_count >= MoRIIOConstants.MAX_RETRIES:
+                if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
                     logger.error(
-                        f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
+                        f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
                     )
                     break
 
@@ -1590,12 +1587,19 @@ class MoRIIOConnectorWorker:
 
                 if self.metadata_socket not in socks:
                     continue
+    def close(self):
+        if hasattr(self, '_handshake_initiation_executor'):
+            self._handshake_initiation_executor.shutdown(wait=False)
+
+        if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t:
+            self._moriio_handshake_listener_t.join(timeout=0)
+
+        if hasattr(self, 'zmq_context') and self.zmq_context:
+            self.zmq_context.destroy(linger=0)
+            self.zmq_context = None
 
     def __del__(self):
-        """Cleanup background threads on destruction."""
-        self._handshake_initiation_executor.shutdown(wait=False)
-        if self._moriio_handshake_listener_t:
-            self._moriio_handshake_listener_t.join(timeout=0)
+        self.close()
 
     @staticmethod
     def _moriio_handshake_listener(
@@ -1744,7 +1748,7 @@ class MoRIIOConnectorWorker:
         def request_ready(_f: Future[Any], entry=(req_id, meta)):
             logger.info("MoRIIO handshake done for request %s", req_id)
             self._ready_requests.put(entry)
-            self.load_ready_flag = True
+            self.load_ready_flag[remote_engine_id] = True
             self.write_ready_flags[remote_engine_id] = True
 
         fut_list = []
@@ -1826,15 +1830,6 @@ class MoRIIOConnectorWorker:
         kv_caches_base_addr = []
         caches_data = []
 
-        # Note(tms): I modified this from the original region setup code.
-        # K and V are now in different regions. Advantage is that we can
-        # elegantly support MLA and any cases where the K and V tensors
-        # are non-contiguous (it's not locally guaranteed that they will be)
-        # Disadvantage is that the encoded MoRIIOAgentMetadata is now larger
-        # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
-        # to better exploit the memory layout (ie num_blocks is the first dim).
-
         for cache_or_caches in kv_caches.values():
             cache_list = (
                 [cache_or_caches]
@@ -2043,18 +2038,20 @@ class MoRIIOConnectorWorker:
             self._read_blocks_for_req(req_id, meta)
 
         # Start transfers for requests whose handshakes have now finished.
-        while True:  # TODO
+        while True:
             if (
                 self._ready_requests.empty()
-                and not self.load_ready_flag
+                and remote_engine_id not in self.load_ready_flag
                 and wait_handshake_readd_req
             ):
                 continue
-            elif not self._ready_requests.empty() and self.load_ready_flag:
+            elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag:
                 self._read_blocks_for_req(*self._ready_requests.get_nowait())
                 break
             else:
                 break
+
+        self._reqs_to_send.update(metadata.reqs_to_send)
 
@@ -2233,8 +2230,8 @@ class MoRIIOConnectorWorker:
     ) -> None:
         if self.mode == MoRIIOMode.WRITE:
             return
-        # we only test TP<->TP in read mode
-        # assert self.dp_rank>0, "only test TP<->TP in read mode"
+
+
         dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
         sessions = self._get_built_session(dp0_engine_id)
 
@@ -2248,7 +2245,7 @@ class MoRIIOConnectorWorker:
                 layer_name
             )
             transfer_status = self.moriio_wrapper.read_remote_data(
-                offs[0], offs[1], offs[2], sessions[sess_idx]
+                offs[2], offs[0], offs[1], sessions[sess_idx]
             )
             self._recving_transfers[request_id].append(transfer_status)
 