mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:34:27 +08:00
[V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (#13869)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
6247bae6c6
commit
5db6b2c961
@ -15,8 +15,6 @@ DTYPE = "half"
|
||||
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
|
||||
"""Set up VllmRunner instance."""
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
return vllm_runner(
|
||||
MODEL,
|
||||
dtype=DTYPE,
|
||||
|
||||
@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str:
|
||||
return f"ipc://{base_rpc_path}/{uuid4()}"
|
||||
|
||||
|
||||
def get_open_zmq_inproc_path() -> str:
|
||||
return f"inproc://{uuid4()}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
"""
|
||||
Get an open port for the vLLM process to listen on.
|
||||
@ -2108,12 +2112,12 @@ def get_exception_traceback():
|
||||
def make_zmq_socket(
|
||||
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
|
||||
path: str,
|
||||
type: Any,
|
||||
socket_type: Any,
|
||||
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
socket = ctx.socket(type)
|
||||
socket = ctx.socket(socket_type)
|
||||
|
||||
# Calculate buffer size based on system memory
|
||||
total_mem = mem.total / 1024**3
|
||||
@ -2127,29 +2131,27 @@ def make_zmq_socket(
|
||||
else:
|
||||
buf_size = -1 # Use system default buffer size
|
||||
|
||||
if type == zmq.constants.PULL:
|
||||
if socket_type == zmq.constants.PULL:
|
||||
socket.setsockopt(zmq.constants.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
|
||||
socket.connect(path)
|
||||
elif type == zmq.constants.PUSH:
|
||||
elif socket_type == zmq.constants.PUSH:
|
||||
socket.setsockopt(zmq.constants.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
|
||||
socket.bind(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {type}")
|
||||
raise ValueError(f"Unknown Socket Type: {socket_type}")
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, type)
|
||||
yield make_zmq_socket(ctx, path, socket_type)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
@ -18,8 +18,8 @@ import zmq.asyncio
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||
make_zmq_socket)
|
||||
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
|
||||
kill_process_tree, make_zmq_socket)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
@ -202,10 +202,11 @@ class BackgroundResources:
|
||||
"""Used as a finalizer for clean shutdown, avoiding
|
||||
circular reference back to the client object."""
|
||||
|
||||
ctx: Union[zmq.Context, zmq.asyncio.Context] = None
|
||||
ctx: Union[zmq.Context] = None
|
||||
output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
|
||||
input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None
|
||||
proc_handle: Optional[BackgroundProcHandle] = None
|
||||
shutdown_path: Optional[str] = None
|
||||
|
||||
def __call__(self):
|
||||
"""Clean up background resources."""
|
||||
@ -218,8 +219,13 @@ class BackgroundResources:
|
||||
self.output_socket.close(linger=0)
|
||||
if self.input_socket is not None:
|
||||
self.input_socket.close(linger=0)
|
||||
if self.ctx is not None:
|
||||
self.ctx.destroy(linger=0)
|
||||
if self.shutdown_path is not None:
|
||||
# We must ensure that the sync output socket is
|
||||
# closed cleanly in its own thread.
|
||||
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
|
||||
shutdown_sender.connect(self.shutdown_path)
|
||||
# Send shutdown signal.
|
||||
shutdown_sender.send(b'')
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
@ -261,28 +267,23 @@ class MPClient(EngineCoreClient):
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# ZMQ setup.
|
||||
self.ctx = (
|
||||
zmq.asyncio.Context() # type: ignore[attr-defined]
|
||||
if asyncio_mode else zmq.Context()) # type: ignore[attr-defined]
|
||||
sync_ctx = zmq.Context()
|
||||
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
|
||||
|
||||
# This will ensure resources created so far are closed
|
||||
# when the client is garbage collected, even if an
|
||||
# exception is raised mid-construction.
|
||||
resources = BackgroundResources(ctx=self.ctx)
|
||||
self._finalizer = weakref.finalize(self, resources)
|
||||
self.resources = BackgroundResources(ctx=sync_ctx)
|
||||
self._finalizer = weakref.finalize(self, self.resources)
|
||||
|
||||
# Paths and sockets for IPC.
|
||||
output_path = get_open_zmq_ipc_path()
|
||||
# Paths for IPC.
|
||||
self.output_path = get_open_zmq_ipc_path()
|
||||
input_path = get_open_zmq_ipc_path()
|
||||
resources.output_socket = make_zmq_socket(self.ctx, output_path,
|
||||
zmq.constants.PULL)
|
||||
resources.input_socket = make_zmq_socket(self.ctx, input_path,
|
||||
zmq.constants.PUSH)
|
||||
|
||||
# Start EngineCore in background process.
|
||||
resources.proc_handle = BackgroundProcHandle(
|
||||
self.resources.proc_handle = BackgroundProcHandle(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
output_path=self.output_path,
|
||||
process_name="EngineCore",
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
process_kwargs={
|
||||
@ -291,8 +292,10 @@ class MPClient(EngineCoreClient):
|
||||
"log_stats": log_stats,
|
||||
})
|
||||
|
||||
self.output_socket = resources.output_socket
|
||||
self.input_socket = resources.input_socket
|
||||
# Create input socket.
|
||||
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
|
||||
zmq.constants.PUSH)
|
||||
self.input_socket = self.resources.input_socket
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
def shutdown(self):
|
||||
@ -325,27 +328,48 @@ class SyncMPClient(MPClient):
|
||||
|
||||
# Ensure that the outputs socket processing thread does not have
|
||||
# a ref to the client which prevents gc.
|
||||
output_socket = self.output_socket
|
||||
ctx = self.ctx
|
||||
output_path = self.output_path
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
|
||||
shutdown_path = get_open_zmq_inproc_path()
|
||||
self.resources.shutdown_path = shutdown_path
|
||||
|
||||
def process_outputs_socket():
|
||||
shutdown_socket = ctx.socket(zmq.PAIR)
|
||||
shutdown_socket.bind(shutdown_path)
|
||||
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
|
||||
try:
|
||||
poller = zmq.Poller()
|
||||
poller.register(shutdown_socket)
|
||||
poller.register(out_socket)
|
||||
while True:
|
||||
(frame, ) = output_socket.recv_multipart(copy=False)
|
||||
socks = poller.poll()
|
||||
if not socks:
|
||||
continue
|
||||
if len(socks) == 2 or socks[0][0] == shutdown_socket:
|
||||
# shutdown signal, exit thread.
|
||||
break
|
||||
|
||||
(frame, ) = out_socket.recv_multipart(copy=False)
|
||||
outputs = decoder.decode(frame.buffer)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
else:
|
||||
outputs_queue.put_nowait(outputs)
|
||||
except zmq.error.ContextTerminated:
|
||||
# Expected when the class is GC'd / during process termination.
|
||||
pass
|
||||
finally:
|
||||
# Close sockets.
|
||||
shutdown_socket.close(linger=0)
|
||||
out_socket.close(linger=0)
|
||||
|
||||
# Process outputs from engine in separate thread.
|
||||
Thread(target=process_outputs_socket, daemon=True).start()
|
||||
self.output_queue_thread = Thread(target=process_outputs_socket,
|
||||
name="EngineCoreOutputQueueThread",
|
||||
daemon=True)
|
||||
self.output_queue_thread.start()
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
return self.outputs_queue.get()
|
||||
@ -424,10 +448,13 @@ class AsyncMPClient(MPClient):
|
||||
# Perform IO in separate task to parallelize as much as possible.
|
||||
# Avoid task having direct reference back to the client.
|
||||
self.outputs_queue = asyncio.Queue()
|
||||
output_socket = self.output_socket
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
output_path = self.output_path
|
||||
output_socket = make_zmq_socket(self.ctx, output_path,
|
||||
zmq.constants.PULL)
|
||||
self.resources.output_socket = output_socket
|
||||
|
||||
async def process_outputs_socket():
|
||||
while True:
|
||||
@ -439,7 +466,8 @@ class AsyncMPClient(MPClient):
|
||||
else:
|
||||
outputs_queue.put_nowait(outputs)
|
||||
|
||||
self.queue_task = asyncio.create_task(process_outputs_socket())
|
||||
self.queue_task = asyncio.create_task(process_outputs_socket(),
|
||||
name="EngineCoreOutputQueueTask")
|
||||
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
if self.outputs_queue is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user