Improve shutdown: close ZMQ sockets on shutdown and raise HandshakeError on unexpected handshake frames

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-01 08:56:53 +00:00
parent bbe6dad401
commit 6fbeee78d1

View File

@ -726,6 +726,7 @@ class MoRIIOWrapper:
return
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
logger.debug("Failed to decode msgpack message, will try as string. Error: %s")
pass
try:
@ -802,6 +803,16 @@ class MoRIIOWrapper:
done_write_cache = set(self.done_write_cache_req_ids)
self.done_write_cache_req_ids = []
return done_write_cache
def shutdown(self):
    """Close every ZMQ socket held in ``self.paths`` and empty the mapping.

    Each socket is closed with ``linger=0``. Cleanup is best-effort: a
    failure to close one socket is logged as a warning and does not stop
    the remaining sockets from being closed. ``self.paths`` is cleared
    afterwards regardless of individual failures.
    """
    logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
    for path, sock in self.paths.items():
        try:
            sock.close(linger=0)
        except Exception as e:
            # Deliberately broad: shutdown must keep going for the other
            # sockets no matter what one of them raises.
            logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
        else:
            logger.debug("Closed ZMQ socket for path: %s", path)
    self.paths.clear()
class MoRIIOAgentMetadata(
@ -1011,6 +1022,12 @@ class MoRIIOConnector(KVConnectorBase_V1):
def wait_for_save(self):
    """Do nothing and return ``None``; this hook is intentionally a no-op."""
    pass
def shutdown(self):
    """Shut down the worker-side and scheduler-side connector components.

    Calls ``shutdown()`` on ``self.connector_worker`` and then on
    ``self.connector_scheduler``. Either attribute may be ``None``, in
    which case it is skipped. If the worker-side shutdown raises, the
    scheduler-side shutdown is still attempted before the exception
    propagates to the caller.
    """
    try:
        if self.connector_worker is not None:
            self.connector_worker.shutdown()
    finally:
        # Run the scheduler-side cleanup even if the worker-side one raised,
        # so a failure in one component does not leave the other open.
        if self.connector_scheduler is not None:
            self.connector_scheduler.shutdown()
def has_connector_metadata(self) -> bool:
"""Check whether the connector metadata is currently set.
@ -1351,6 +1368,8 @@ class MoRIIOConnectorWorker:
logger.info("Initializing MoRIIO worker %s", engine_id)
logging.getLogger("aiter").disabled = True
# Config.
self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
@ -1634,7 +1653,18 @@ class MoRIIOConnectorWorker:
time.sleep(MoRIIOConstants.PING_INTERVAL)
index += 1
def close(self):
def shutdown(self):
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper:
self.moriio_wrapper.shutdown()
for path, sock in self.paths.items():
try:
sock.close(linger=0)
logger.debug(f"Closed ZMQ socket for path: {path}")
except Exception as e:
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
self.paths.clear()
if hasattr(self, "_handshake_initiation_executor"):
self._handshake_initiation_executor.shutdown(wait=False)
@ -1649,7 +1679,7 @@ class MoRIIOConnectorWorker:
self.zmq_context = None
def __del__(self):
self.close()
self.shutdown()
@staticmethod
def _moriio_handshake_listener(
@ -1726,7 +1756,7 @@ class MoRIIOConnectorWorker:
sock.send(MoRIIOConstants.GET_META_MSG)
received_frame = sock.recv_multipart()
if len(received_frame) != 2 or received_frame[0] != b"":
assert 0, f"unexpected frame! {received_frame = }"
raise HandshakeError(f"Unexpected frame! {received_frame = }")
metadata_bytes = received_frame[1]
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
@ -1767,7 +1797,7 @@ class MoRIIOConnectorWorker:
received_frame = sock.recv_multipart()
if len(received_frame) != 2 or received_frame[0] != b"":
assert 0, f"Unexpected frame! {received_frame = }"
raise HandshakeError(f"unexpected frame! {received_frame = }")
buf = received_frame[1]
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
msgpack.loads(buf)