[Bugfix][core] replace heartbeat with pid check (#9818)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-10-30 10:34:07 -06:00 committed by GitHub
parent 9ff4511e43
commit 3b3f1e7436
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 62 deletions

View File

@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@ -266,3 +266,28 @@ async def test_mp_cuda_init():
async with build_async_engine_client(args):
pass
@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
assert client.is_running
# kill the engine process
engine.proc.kill()
# Generate call should fail
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
# And the health check should show the engine is dead
with pytest.raises(RuntimeError, match="Engine process .* died"):
await client.check_health()
client.close()

View File

@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config)
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
while True:
try:
await client.setup()

View File

@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)
import cloudpickle
import psutil
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: EngineConfig):
def __init__(self, ipc_path: str, engine_config: EngineConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
self._engine_process = psutil.Process(engine_pid)
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
socket.close(linger=0)
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""Background loop that continually checks to ensure the engine process
is still alive.
"""
try:
while True:
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
# Check if the engine process is running:
if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) "
"died."))
break
else:
if await self.heartbeat_socket.poll(timeout=timeout):
# Heartbeat received- check the message
await self._check_success(
error_message="Heartbeat failed.",
@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))
except Exception as e:
self._set_errored(e)

View File

@ -1,7 +1,5 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.envs import VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
@ -108,20 +106,6 @@ class MQLLMEngine:
# Error state.
self._errored_with: Optional[BaseException] = None
# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@ -157,8 +141,6 @@ class MQLLMEngine:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
@ -172,7 +154,6 @@ class MQLLMEngine:
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine
@ -211,11 +192,12 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine."""
while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
# When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
@ -314,32 +296,16 @@ class MQLLMEngine:
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()
logger.debug("Exiting MQLLMEngine heartbeat thread")
def _heartbeat(self):
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))
else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
@ -369,9 +335,6 @@ class MQLLMEngine:
if self._errored_with is None:
self._errored_with = e
def _alive(self):
self._last_alive_time = time.time()
def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()

View File

@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
UsageContext.OPENAI_API_SERVER,
ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start"
logger.info("Started engine process with PID %d", engine_pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
engine_pid)
try:
while True: