[BugFix] Threadsafe close async zmq sockets (#22877)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Nick Hill 2025-08-14 03:44:29 -07:00 committed by GitHub
parent 7c3a0741c6
commit eb08487b18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 26 deletions

View File

@ -709,8 +709,28 @@ class AsyncMicrobatchTokenizer:
def cancel_task_threadsafe(task: Task):
    """Cancel ``task`` safely from any thread.

    A falsy ``task`` (e.g. ``None``) or one that has already finished is
    ignored. Otherwise ``task.cancel`` is run on the task's own event loop:
    immediately when the caller is already executing in that loop, or
    scheduled onto it when called from another thread.
    """
    if task and not task.done():
        run_in_loop(task.get_loop(), task.cancel)
def close_sockets(
        sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket, None]]):
    """Close every socket in ``sockets`` without lingering.

    Args:
        sockets: Sockets to close. ``None`` entries are allowed and are
            skipped, so callers can pass optional socket fields directly.

    Each socket is closed with ``linger=0`` so pending messages are dropped
    rather than delaying shutdown.
    """
    for sock in sockets:
        if sock is not None:
            sock.close(linger=0)
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    """Run ``function(*args)`` on ``loop``, callable from any thread.

    If the caller is already executing inside ``loop``, the function is
    invoked synchronously. Otherwise it is handed to the loop with
    ``call_soon_threadsafe``. Nothing is run if the loop is closed.
    The function's return value is discarded in both cases.

    Args:
        loop: Event loop the function must run on.
        function: Callable to invoke.
        *args: Positional arguments forwarded to ``function``.
    """
    if in_loop(loop):
        function(*args)
        return
    if loop.is_closed():
        return
    try:
        loop.call_soon_threadsafe(function, *args)
    except RuntimeError:
        # The loop may be closed by its owning thread between the
        # is_closed() check above and the scheduling call. Treat that the
        # same as "already closed"; re-raise anything else.
        if not loop.is_closed():
            raise
def in_loop(event_loop: AbstractEventLoop) -> bool:
    """Return whether the caller is currently running inside ``event_loop``.

    Returns ``False`` when the current thread has no running event loop,
    or when its running loop is a different loop object.
    """
    try:
        # Loops are compared by identity: the question is whether this is
        # the very same loop object, not an "equal" one.
        return asyncio.get_running_loop() is event_loop
    except RuntimeError:
        # get_running_loop() raises when no loop is running in this thread.
        return False
def make_async( def make_async(

View File

@ -23,8 +23,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils import (cancel_task_threadsafe, get_open_port, from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path,
get_open_zmq_inproc_path, make_zmq_socket) in_loop, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType, ReconfigureDistributedRequest, ReconfigureRankType,
@ -317,7 +317,7 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object.""" circular reference back to the client object."""
ctx: Union[zmq.Context] ctx: zmq.Context
# If CoreEngineProcManager, it manages local engines; # If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines. # if CoreEngineActorManager, it manages all engines.
engine_manager: Optional[Union[CoreEngineProcManager, engine_manager: Optional[Union[CoreEngineProcManager,
@ -326,6 +326,8 @@ class BackgroundResources:
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
first_req_send_socket: Optional[zmq.asyncio.Socket] = None first_req_send_socket: Optional[zmq.asyncio.Socket] = None
first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None
stats_update_socket: Optional[zmq.asyncio.Socket] = None
output_queue_task: Optional[asyncio.Task] = None output_queue_task: Optional[asyncio.Task] = None
stats_update_task: Optional[asyncio.Task] = None stats_update_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
@ -343,23 +345,47 @@ class BackgroundResources:
if self.coordinator is not None: if self.coordinator is not None:
self.coordinator.close() self.coordinator.close()
cancel_task_threadsafe(self.output_queue_task) if isinstance(self.output_socket, zmq.asyncio.Socket):
cancel_task_threadsafe(self.stats_update_task) # Async case.
loop = self.output_socket._get_loop()
asyncio.get_running_loop()
sockets = (self.output_socket, self.input_socket,
self.first_req_send_socket, self.first_req_rcv_socket,
self.stats_update_socket)
# ZMQ context termination can hang if the sockets tasks = (self.output_queue_task, self.stats_update_task)
# aren't explicitly closed first.
for socket in (self.output_socket, self.input_socket,
self.first_req_send_socket):
if socket is not None:
socket.close(linger=0)
if self.shutdown_path is not None: def close_sockets_and_tasks():
# We must ensure that the sync output socket is close_sockets(sockets)
# closed cleanly in its own thread. for task in tasks:
with self.ctx.socket(zmq.PAIR) as shutdown_sender: if task is not None and not task.done():
shutdown_sender.connect(self.shutdown_path) task.cancel()
# Send shutdown signal.
shutdown_sender.send(b'') if in_loop(loop):
close_sockets_and_tasks()
elif not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks)
else:
# Loop has been closed, try to clean up directly.
del tasks
del close_sockets_and_tasks
close_sockets(sockets)
del self.output_queue_task
del self.stats_update_task
else:
# Sync case.
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
close_sockets((self.output_socket, self.input_socket))
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b'')
def validate_alive(self, frames: Sequence[zmq.Frame]): def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer if len(frames) == 1 and (frames[0].buffer
@ -969,14 +995,19 @@ class DPAsyncMPClient(AsyncMPClient):
self.engine_ranks_managed[-1] + 1) self.engine_ranks_managed[-1] + 1)
async def run_engine_stats_update_task(): async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address, with (make_zmq_socket(self.ctx,
zmq.XSUB) as socket, make_zmq_socket( self.stats_update_address,
self.ctx, zmq.XSUB,
self.first_req_sock_addr, linger=0) as socket,
zmq.PAIR, make_zmq_socket(self.ctx,
bind=False) as first_req_rcv_socket: self.first_req_sock_addr,
zmq.PAIR,
bind=False,
linger=0) as first_req_rcv_socket):
assert isinstance(socket, zmq.asyncio.Socket) assert isinstance(socket, zmq.asyncio.Socket)
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
self.resources.stats_update_socket = socket
self.resources.first_req_rcv_socket = first_req_rcv_socket
# Send subscription message. # Send subscription message.
await socket.send(b'\x01') await socket.send(b'\x01')