From 6fbeee78d1d43f57c8a546a9ca8896f29b38d50b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 08:56:53 +0000 Subject: [PATCH] improve shutdown Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 38 +++++++++++++++++-- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 5b8370a56e488..781a76baab2bf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -726,6 +726,7 @@ class MoRIIOWrapper: return except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): + logger.debug("Failed to decode msgpack message, will try as string. Error: %s") pass try: @@ -802,6 +803,16 @@ class MoRIIOWrapper: done_write_cache = set(self.done_write_cache_req_ids) self.done_write_cache_req_ids = [] return done_write_cache + + def shutdown(self): + logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug(f"Closed ZMQ socket for path: {path}") + except Exception as e: + logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + self.paths.clear() class MoRIIOAgentMetadata( @@ -1011,6 +1022,12 @@ class MoRIIOConnector(KVConnectorBase_V1): def wait_for_save(self): pass + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + if self.connector_scheduler is not None: + self.connector_scheduler.shutdown() + def has_connector_metadata(self) -> bool: """Check whether the connector metadata is currently set. @@ -1351,6 +1368,8 @@ class MoRIIOConnectorWorker: logger.info("Initializing MoRIIO worker %s", engine_id) + logging.getLogger("aiter").disabled = True + # Config. self.vllm_config = vllm_config assert vllm_config.kv_transfer_config is not None, ( @@ -1634,7 +1653,18 @@ class MoRIIOConnectorWorker: time.sleep(MoRIIOConstants.PING_INTERVAL) index += 1 - def close(self): + def shutdown(self): + if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper: + self.moriio_wrapper.shutdown() + + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug(f"Closed ZMQ socket for path: {path}") + except Exception as e: + logger.warning(f"Error closing ZMQ socket for path {path}: {e}") + self.paths.clear() + if hasattr(self, "_handshake_initiation_executor"): self._handshake_initiation_executor.shutdown(wait=False) @@ -1649,7 +1679,7 @@ class MoRIIOConnectorWorker: self.zmq_context = None def __del__(self): - self.close() + self.shutdown() @staticmethod def _moriio_handshake_listener( @@ -1726,7 +1756,7 @@ class MoRIIOConnectorWorker: sock.send(MoRIIOConstants.GET_META_MSG) received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - assert 0, f"unexpected frame! {received_frame = }" + raise HandshakeError(f"Unexpected frame! {received_frame = }") metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) @@ -1767,7 +1797,7 @@ class MoRIIOConnectorWorker: received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - assert 0, f"Unexpected frame! {received_frame = }" + raise HandshakeError(f"unexpected frame! {received_frame = }") buf = received_frame[1] self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf)