[V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (#13869)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-03-04 07:06:47 -08:00 committed by GitHub
parent 6247bae6c6
commit 5db6b2c961
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 40 deletions

View File

@@ -15,8 +15,6 @@ DTYPE = "half"
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
return vllm_runner(
MODEL,
dtype=DTYPE,

View File

@@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str:
return f"ipc://{base_rpc_path}/{uuid4()}"
def get_open_zmq_inproc_path() -> str:
    """Return a fresh ``inproc://`` endpoint address.

    The address is the ``inproc://`` scheme followed by a newly generated
    ``uuid4()`` value, so each call yields a different string.

    Returns:
        A string of the form ``inproc://<uuid4>``.
    """
    return f"inproc://{uuid4()}"
def get_open_port() -> int:
"""
Get an open port for the vLLM process to listen on.
@@ -2108,12 +2112,12 @@ def get_exception_traceback():
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
type: Any,
socket_type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
mem = psutil.virtual_memory()
socket = ctx.socket(type)
socket = ctx.socket(socket_type)
# Calculate buffer size based on system memory
total_mem = mem.total / 1024**3
@@ -2127,29 +2131,27 @@ def make_zmq_socket(
else:
buf_size = -1 # Use system default buffer size
if type == zmq.constants.PULL:
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
elif type == zmq.constants.PUSH:
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
raise ValueError(f"Unknown Socket Type: {socket_type}")
return socket
@contextlib.contextmanager
def zmq_socket_ctx(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, type)
yield make_zmq_socket(ctx, path, socket_type)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")

View File

@@ -18,8 +18,8 @@ import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
kill_process_tree, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
@@ -202,10 +202,11 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
ctx: Union[zmq.Context, zmq.asyncio.Context] = None
ctx: Union[zmq.Context] = None
output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None
def __call__(self):
"""Clean up background resources."""
@@ -218,8 +219,13 @@ class BackgroundResources:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.ctx is not None:
self.ctx.destroy(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b'')
class MPClient(EngineCoreClient):
@@ -261,28 +267,23 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
self.ctx = (
zmq.asyncio.Context() # type: ignore[attr-defined]
if asyncio_mode else zmq.Context()) # type: ignore[attr-defined]
sync_ctx = zmq.Context()
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed
# when the client is garbage collected, even if an
# exception is raised mid-construction.
resources = BackgroundResources(ctx=self.ctx)
self._finalizer = weakref.finalize(self, resources)
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)
# Paths and sockets for IPC.
output_path = get_open_zmq_ipc_path()
# Paths for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
resources.output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
# Start EngineCore in background process.
resources.proc_handle = BackgroundProcHandle(
self.resources.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
output_path=self.output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
@@ -291,8 +292,10 @@ class MPClient(EngineCoreClient):
"log_stats": log_stats,
})
self.output_socket = resources.output_socket
self.input_socket = resources.input_socket
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {}
def shutdown(self):
@@ -325,27 +328,48 @@ class SyncMPClient(MPClient):
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
output_socket = self.output_socket
ctx = self.ctx
output_path = self.output_path
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
shutdown_path = get_open_zmq_inproc_path()
self.resources.shutdown_path = shutdown_path
def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR)
shutdown_socket.bind(shutdown_path)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try:
poller = zmq.Poller()
poller.register(shutdown_socket)
poller.register(out_socket)
while True:
(frame, ) = output_socket.recv_multipart(copy=False)
socks = poller.poll()
if not socks:
continue
if len(socks) == 2 or socks[0][0] == shutdown_socket:
# shutdown signal, exit thread.
break
(frame, ) = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frame.buffer)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
else:
outputs_queue.put_nowait(outputs)
except zmq.error.ContextTerminated:
# Expected when the class is GC'd / during process termination.
pass
finally:
# Close sockets.
shutdown_socket.close(linger=0)
out_socket.close(linger=0)
# Process outputs from engine in separate thread.
Thread(target=process_outputs_socket, daemon=True).start()
self.output_queue_thread = Thread(target=process_outputs_socket,
name="EngineCoreOutputQueueThread",
daemon=True)
self.output_queue_thread.start()
def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()
@@ -424,10 +448,13 @@ class AsyncMPClient(MPClient):
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue()
output_socket = self.output_socket
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
self.resources.output_socket = output_socket
async def process_outputs_socket():
while True:
@@ -439,7 +466,8 @@ class AsyncMPClient(MPClient):
else:
outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket())
self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None: