[Bugfix][core] replace heartbeat with pid check (#9818)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-10-30 10:34:07 -06:00 committed by GitHub
parent 9ff4511e43
commit 3b3f1e7436
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 62 deletions

View File

@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it" MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL) ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError RAISED_ERROR = KeyError
RAISED_VALUE = "foo" RAISED_VALUE = "foo"
@ -266,3 +266,28 @@ async def test_mp_cuda_init():
async with build_async_engine_client(args): async with build_async_engine_client(args):
pass pass
@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
    """Check that killing the engine process is reported to the client.

    Starts a remote engine, kills its process, and then verifies that:
      * an in-flight ``generate`` call raises ``MQEngineDeadError``, and
      * ``check_health`` raises a ``RuntimeError`` whose message matches
        ``"Engine process .* died"``.

    The client is always closed, even if one of the assertions fails, so a
    failing run does not leak the client's resources into later tests.
    """
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()
        try:
            assert client.is_running

            # Kill the engine process out from under the client.
            engine.proc.kill()

            # Generate call should fail
            with pytest.raises(MQEngineDeadError):
                # NOTE(review): request ids are presumably strings elsewhere
                # in the engine API, so pass the UUID as text -- confirm.
                async for _ in client.generate(
                        prompt="Hello my name is",
                        sampling_params=SamplingParams(),
                        request_id=str(uuid.uuid4())):
                    pass

            # And the health check should show the engine is dead
            with pytest.raises(RuntimeError, match="Engine process .* died"):
                await client.check_health()
        finally:
            # Release the client regardless of the test outcome.
            client.close()

View File

@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
async def make_client(self) -> MQLLMEngineClient: async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config() engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config) client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
while True: while True:
try: try:
await client.setup() await client.setup()

View File

@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload) Optional, Union, cast, overload)
import cloudpickle import cloudpickle
import psutil
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined] from zmq import Frame # type: ignore[attr-defined]
@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy every N seconds, confirming the engine is healthy
""" """
def __init__(self, ipc_path: str, engine_config: EngineConfig): def __init__(self, ipc_path: str, engine_config: EngineConfig,
engine_pid: int):
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically. # Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready. # Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None self.health_loop: Optional[asyncio.Task] = None
self._engine_process = psutil.Process(engine_pid)
@staticmethod @staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs): def is_unsupported_config(engine_args: AsyncEngineArgs):
@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
socket.close(linger=0) socket.close(linger=0)
async def run_heartbeat_loop(self, timeout: int): async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for """Background loop that continually checks to ensure the engine process
heartbeats. is still alive.
""" """
try: try:
while True: while True:
if await self.heartbeat_socket.poll(timeout=timeout) == 0: # Check if the engine process is running:
# No heartbeat was received. Set error and exit the loop if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored( self._set_errored(
TimeoutError("No heartbeat received " RuntimeError(
"from MQLLMEngine")) f"Engine process (pid {self._engine_process.pid}) "
logger.debug("Shutting down MQLLMEngineClient check " "died."))
"health loop due to timeout")
break break
else: if await self.heartbeat_socket.poll(timeout=timeout):
# Heartbeat received- check the message # Heartbeat received- check the message
await self._check_success( await self._check_success(
error_message="Heartbeat failed.", error_message="Heartbeat failed.",
@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.") logger.debug("Shutting down MQLLMEngineClient check health loop.")
except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)

View File

@ -1,7 +1,5 @@
import pickle import pickle
import signal import signal
import threading
import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1 from vllm.envs import VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
@ -108,20 +106,6 @@ class MQLLMEngine:
# Error state. # Error state.
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
if self._errored_with is not None: if self._errored_with is not None:
@ -157,8 +141,6 @@ class MQLLMEngine:
try: try:
logger.debug("Starting Startup Loop.") logger.debug("Starting Startup Loop.")
self.run_startup_loop() self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.") logger.debug("Starting Engine Loop.")
self.run_engine_loop() self.run_engine_loop()
except Exception as e: except Exception as e:
@ -172,7 +154,6 @@ class MQLLMEngine:
def cleanup(self): def cleanup(self):
"""Cleanup zeromq state on shutdown.""" """Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context. # Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0) self.ctx.destroy(linger=0)
del self.engine del self.engine
@ -211,11 +192,12 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine.""" """Core busy loop of the LLMEngine."""
while True: while True:
self._alive()
if not self.engine.has_unfinished_requests(): if not self.engine.has_unfinished_requests():
# Poll until there is work to do. # Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive() # When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats() self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.") logger.debug("Waiting for new requests in engine loop.")
@ -314,32 +296,16 @@ class MQLLMEngine:
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request.request_id) logger.info("Aborted request %s.", request.request_id)
def _heartbeat_loop(self): def _health_check(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()
logger.debug("Exiting MQLLMEngine heartbeat thread")
def _heartbeat(self):
# Send unhealthy if engine has already errored # Send unhealthy if engine has already errored
if self._errored_with is not None: if self._errored_with is not None:
self._send_unhealthy(self._errored_with) self._send_unhealthy(self._errored_with)
try:
# Check for life of the main loop self.engine.check_health()
elif time.time() - self._last_alive_time > self.last_alive_threshold: self._send_healthy()
self._send_unhealthy(RuntimeError("Engine loop has died")) except Exception as e:
self._set_errored(e)
else: self._send_unhealthy(e)
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient.""" """Send List of RequestOutput to RPCClient."""
@ -369,9 +335,6 @@ class MQLLMEngine:
if self._errored_with is None: if self._errored_with is None:
self._errored_with = e self._errored_with = e
def _alive(self):
self._last_alive_time = time.time()
def start_profile(self) -> None: def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor: if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile() self.engine.model_executor.start_profile()

View File

@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
UsageContext.OPENAI_API_SERVER, UsageContext.OPENAI_API_SERVER,
ipc_path)) ipc_path))
engine_process.start() engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid) engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start"
logger.info("Started engine process with PID %d", engine_pid)
# Build RPCClient, which conforms to EngineClient Protocol. # Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support # NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above) # embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
engine_pid)
try: try:
while True: while True: