diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 095829db8394..cae4eecc0dee 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -709,8 +709,37 @@ class AsyncMicrobatchTokenizer:
 
 
 def cancel_task_threadsafe(task: Task):
-    if task and not task.done() and not (loop := task.get_loop()).is_closed():
-        loop.call_soon_threadsafe(task.cancel)
+    """Cancel `task` on its own event loop; no-op if None or done."""
+    if task and not task.done():
+        run_in_loop(task.get_loop(), task.cancel)
+
+
+def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]):
+    """Close every non-None socket in `sockets` with linger=0."""
+    for sock in sockets:
+        if sock is not None:
+            sock.close(linger=0)
+
+
+def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
+    """Call `function(*args)` on `loop`.
+
+    Runs inline when already on `loop`, otherwise schedules the call
+    thread-safely. Does nothing if `loop` has been closed.
+    """
+    if in_loop(loop):
+        function(*args)
+    elif not loop.is_closed():
+        loop.call_soon_threadsafe(function, *args)
+
+
+def in_loop(event_loop: AbstractEventLoop) -> bool:
+    """Return True if `event_loop` is running in the current thread."""
+    try:
+        return asyncio.get_running_loop() == event_loop
+    except RuntimeError:
+        # No event loop is running in this thread.
+        return False
 
 
 def make_async(
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 05b4d7260896..5ffa555570a2 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -23,8 +23,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
-from vllm.utils import (cancel_task_threadsafe, get_open_port,
-                        get_open_zmq_inproc_path, make_zmq_socket)
+from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path,
+                        in_loop, make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType,
                             ReconfigureDistributedRequest, ReconfigureRankType,
@@ -317,7 +317,7 @@ class BackgroundResources:
     """Used as a finalizer for clean shutdown, avoiding
     circular reference back to the client object."""
 
-    ctx: Union[zmq.Context]
+    ctx: zmq.Context
     # If CoreEngineProcManager, it manages local engines;
     # if CoreEngineActorManager, it manages all engines.
     engine_manager: Optional[Union[CoreEngineProcManager,
@@ -326,6 +326,8 @@ class BackgroundResources:
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     first_req_send_socket: Optional[zmq.asyncio.Socket] = None
+    first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None
+    stats_update_socket: Optional[zmq.asyncio.Socket] = None
     output_queue_task: Optional[asyncio.Task] = None
     stats_update_task: Optional[asyncio.Task] = None
     shutdown_path: Optional[str] = None
@@ -343,23 +345,53 @@ class BackgroundResources:
         if self.coordinator is not None:
             self.coordinator.close()
 
-        cancel_task_threadsafe(self.output_queue_task)
-        cancel_task_threadsafe(self.stats_update_task)
+        if isinstance(self.output_socket, zmq.asyncio.Socket):
+            # Async case.
+            sockets = (self.output_socket, self.input_socket,
+                       self.first_req_send_socket, self.first_req_rcv_socket,
+                       self.stats_update_socket)
 
-        # ZMQ context termination can hang if the sockets
-        # aren't explicitly closed first.
-        for socket in (self.output_socket, self.input_socket,
-                       self.first_req_send_socket):
-            if socket is not None:
-                socket.close(linger=0)
+            tasks = (self.output_queue_task, self.stats_update_task)
 
-        if self.shutdown_path is not None:
-            # We must ensure that the sync output socket is
-            # closed cleanly in its own thread.
-            with self.ctx.socket(zmq.PAIR) as shutdown_sender:
-                shutdown_sender.connect(self.shutdown_path)
-                # Send shutdown signal.
-                shutdown_sender.send(b'')
+            # This finalizer may be invoked with no event loop running in
+            # the current thread, so take the loop from the background
+            # tasks rather than requiring a running loop here.
+            loop = next(
+                (task.get_loop() for task in tasks if task is not None),
+                None)
+
+            def close_sockets_and_tasks():
+                close_sockets(sockets)
+                for task in tasks:
+                    if task is not None and not task.done():
+                        task.cancel()
+
+            if loop is not None and in_loop(loop):
+                close_sockets_and_tasks()
+            elif loop is not None and not loop.is_closed():
+                loop.call_soon_threadsafe(close_sockets_and_tasks)
+            else:
+                # No usable loop (never started, or already closed);
+                # try to clean up directly.
+                del tasks
+                del close_sockets_and_tasks
+                close_sockets(sockets)
+                del self.output_queue_task
+                del self.stats_update_task
+        else:
+            # Sync case.
+
+            # ZMQ context termination can hang if the sockets
+            # aren't explicitly closed first.
+            close_sockets((self.output_socket, self.input_socket))
+
+            if self.shutdown_path is not None:
+                # We must ensure that the sync output socket is
+                # closed cleanly in its own thread.
+                with self.ctx.socket(zmq.PAIR) as shutdown_sender:
+                    shutdown_sender.connect(self.shutdown_path)
+                    # Send shutdown signal.
+                    shutdown_sender.send(b'')
 
     def validate_alive(self, frames: Sequence[zmq.Frame]):
         if len(frames) == 1 and (frames[0].buffer
@@ -969,14 +1001,19 @@ class DPAsyncMPClient(AsyncMPClient):
                             self.engine_ranks_managed[-1] + 1)
 
         async def run_engine_stats_update_task():
-            with make_zmq_socket(self.ctx, self.stats_update_address,
-                                 zmq.XSUB) as socket, make_zmq_socket(
-                                     self.ctx,
-                                     self.first_req_sock_addr,
-                                     zmq.PAIR,
-                                     bind=False) as first_req_rcv_socket:
+            with (make_zmq_socket(self.ctx,
+                                  self.stats_update_address,
+                                  zmq.XSUB,
+                                  linger=0) as socket,
+                  make_zmq_socket(self.ctx,
+                                  self.first_req_sock_addr,
+                                  zmq.PAIR,
+                                  bind=False,
+                                  linger=0) as first_req_rcv_socket):
                 assert isinstance(socket, zmq.asyncio.Socket)
                 assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
+                self.resources.stats_update_socket = socket
+                self.resources.first_req_rcv_socket = first_req_rcv_socket
                 # Send subscription message.
                 await socket.send(b'\x01')
 