mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 04:37:02 +08:00
improve shutdown
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
bbe6dad401
commit
6fbeee78d1
@ -726,6 +726,7 @@ class MoRIIOWrapper:
|
|||||||
|
|
||||||
return
|
return
|
||||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||||
|
logger.debug("Failed to decode msgpack message, will try as string. Error: %s")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -803,6 +804,16 @@ class MoRIIOWrapper:
|
|||||||
self.done_write_cache_req_ids = []
|
self.done_write_cache_req_ids = []
|
||||||
return done_write_cache
|
return done_write_cache
|
||||||
|
|
||||||
|
def shutdown(self):
    """Close every socket stored in ``self.paths`` and empty the mapping.

    Closing is best-effort: if closing one socket raises, the error is
    logged at warning level and the remaining sockets are still closed.
    ``self.paths`` is cleared afterwards regardless of individual failures.
    """
    logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
    for path, sock in self.paths.items():
        try:
            # linger=0 asks the socket not to wait on unsent messages
            # while closing, so shutdown cannot block here.
            sock.close(linger=0)
        except Exception as e:
            # Deliberately broad: one bad socket must not prevent the
            # others from being closed during teardown.
            logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
        else:
            # Lazy %-style arguments, matching the other logger calls in
            # this file, so formatting only happens if the record is emitted.
            logger.debug("Closed ZMQ socket for path: %s", path)
    self.paths.clear()
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOAgentMetadata(
|
class MoRIIOAgentMetadata(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
@ -1011,6 +1022,12 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
|||||||
def wait_for_save(self):
|
def wait_for_save(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def shutdown(self):
    """Shut down the worker and scheduler sides held by this connector.

    ``connector_worker`` is shut down first, then ``connector_scheduler``.
    Either attribute may be ``None``, in which case it is skipped.

    The scheduler is handled in a ``finally`` clause so that an exception
    raised by the worker's ``shutdown()`` does not leave the scheduler
    un-shut-down; the worker's exception still propagates to the caller.
    """
    try:
        if self.connector_worker is not None:
            self.connector_worker.shutdown()
    finally:
        if self.connector_scheduler is not None:
            self.connector_scheduler.shutdown()
|
||||||
|
|
||||||
def has_connector_metadata(self) -> bool:
|
def has_connector_metadata(self) -> bool:
|
||||||
"""Check whether the connector metadata is currently set.
|
"""Check whether the connector metadata is currently set.
|
||||||
|
|
||||||
@ -1351,6 +1368,8 @@ class MoRIIOConnectorWorker:
|
|||||||
|
|
||||||
logger.info("Initializing MoRIIO worker %s", engine_id)
|
logger.info("Initializing MoRIIO worker %s", engine_id)
|
||||||
|
|
||||||
|
logging.getLogger("aiter").disabled = True
|
||||||
|
|
||||||
# Config.
|
# Config.
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
assert vllm_config.kv_transfer_config is not None, (
|
assert vllm_config.kv_transfer_config is not None, (
|
||||||
@ -1634,7 +1653,18 @@ class MoRIIOConnectorWorker:
|
|||||||
time.sleep(MoRIIOConstants.PING_INTERVAL)
|
time.sleep(MoRIIOConstants.PING_INTERVAL)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
def close(self):
|
def shutdown(self):
|
||||||
|
if hasattr(self, 'moriio_wrapper') and self.moriio_wrapper:
|
||||||
|
self.moriio_wrapper.shutdown()
|
||||||
|
|
||||||
|
for path, sock in self.paths.items():
|
||||||
|
try:
|
||||||
|
sock.close(linger=0)
|
||||||
|
logger.debug(f"Closed ZMQ socket for path: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing ZMQ socket for path {path}: {e}")
|
||||||
|
self.paths.clear()
|
||||||
|
|
||||||
if hasattr(self, "_handshake_initiation_executor"):
|
if hasattr(self, "_handshake_initiation_executor"):
|
||||||
self._handshake_initiation_executor.shutdown(wait=False)
|
self._handshake_initiation_executor.shutdown(wait=False)
|
||||||
|
|
||||||
@ -1649,7 +1679,7 @@ class MoRIIOConnectorWorker:
|
|||||||
self.zmq_context = None
|
self.zmq_context = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close()
|
self.shutdown()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _moriio_handshake_listener(
|
def _moriio_handshake_listener(
|
||||||
@ -1726,7 +1756,7 @@ class MoRIIOConnectorWorker:
|
|||||||
sock.send(MoRIIOConstants.GET_META_MSG)
|
sock.send(MoRIIOConstants.GET_META_MSG)
|
||||||
received_frame = sock.recv_multipart()
|
received_frame = sock.recv_multipart()
|
||||||
if len(received_frame) != 2 or received_frame[0] != b"":
|
if len(received_frame) != 2 or received_frame[0] != b"":
|
||||||
assert 0, f"unexpected frame! {received_frame = }"
|
raise HandshakeError(f"Unexpected frame! {received_frame = }")
|
||||||
|
|
||||||
metadata_bytes = received_frame[1]
|
metadata_bytes = received_frame[1]
|
||||||
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
|
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
|
||||||
@ -1767,7 +1797,7 @@ class MoRIIOConnectorWorker:
|
|||||||
|
|
||||||
received_frame = sock.recv_multipart()
|
received_frame = sock.recv_multipart()
|
||||||
if len(received_frame) != 2 or received_frame[0] != b"":
|
if len(received_frame) != 2 or received_frame[0] != b"":
|
||||||
assert 0, f"Unexpected frame! {received_frame = }"
|
raise HandshakeError(f"unexpected frame! {received_frame = }")
|
||||||
buf = received_frame[1]
|
buf = received_frame[1]
|
||||||
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
||||||
msgpack.loads(buf)
|
msgpack.loads(buf)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user