diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py
index 205ab00aa6b1..83bc4e7cf847 100644
--- a/tests/mq_llm_engine/test_error_handling.py
+++ b/tests/mq_llm_engine/test_error_handling.py
@@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 
 MODEL = "google/gemma-1.1-2b-it"
-ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
 RAISED_ERROR = KeyError
 RAISED_VALUE = "foo"
 
@@ -266,3 +266,28 @@ async def test_mp_cuda_init():
 
     async with build_async_engine_client(args):
         pass
+
+
+@pytest.mark.asyncio
+async def test_engine_process_death(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # kill the engine process
+        engine.proc.kill()
+
+        # Generate call should fail
+        with pytest.raises(MQEngineDeadError):
+            async for _ in client.generate(prompt="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+
+        # And the health check should show the engine is dead
+        with pytest.raises(RuntimeError, match="Engine process .* died"):
+            await client.check_health()
+
+        client.close()
diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py
index 3ffa126070ca..f717c1355431 100644
--- a/tests/mq_llm_engine/utils.py
+++ b/tests/mq_llm_engine/utils.py
@@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
 
     async def make_client(self) -> MQLLMEngineClient:
         engine_config = self.engine_args.create_engine_config()
-        client = MQLLMEngineClient(self.ipc_path, engine_config)
+        client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
         while True:
             try:
                 await client.setup()
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 9e5a6b21f4c1..6e6630b3ff55 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
                     Optional, Union, cast, overload)
 
 import cloudpickle
+import psutil
 import zmq
 import zmq.asyncio
 from zmq import Frame  # type: ignore[attr-defined]
@@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
             every N seconds, confirming the engine is healthy
     """
 
-    def __init__(self, ipc_path: str, engine_config: EngineConfig):
+    def __init__(self, ipc_path: str, engine_config: EngineConfig,
+                 engine_pid: int):
         self.context = zmq.asyncio.Context()
         self._errored_with: Optional[BaseException] = None
 
@@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
         # Loop to check health of the LLMEngine periodically.
         # Started after the MQLLMEngine is ready.
         self.health_loop: Optional[asyncio.Task] = None
+        self._engine_process = psutil.Process(engine_pid)
 
     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
@@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
             socket.close(linger=0)
 
     async def run_heartbeat_loop(self, timeout: int):
-        """Background loop that continually listens to the RPCServer for
-        heartbeats.
+        """Background loop that continually checks to ensure the engine process
+        is still alive.
         """
         try:
             while True:
-                if await self.heartbeat_socket.poll(timeout=timeout) == 0:
-                    # No heartbeat was received. Set error and exit the loop
+                # Check if the engine process is running:
+                if not self._engine_process.is_running() or (
+                        self._engine_process.status() == psutil.STATUS_ZOMBIE):
+                    # NB: is_running() returns True for zombies
                     self._set_errored(
-                        TimeoutError("No heartbeat received "
-                                     "from MQLLMEngine"))
-                    logger.debug("Shutting down MQLLMEngineClient check "
-                                 "health loop due to timeout")
+                        RuntimeError(
+                            f"Engine process (pid {self._engine_process.pid}) "
+                            "died."))
                     break
 
-                else:
+                if await self.heartbeat_socket.poll(timeout=timeout):
                     # Heartbeat received- check the message
                     await self._check_success(
                         error_message="Heartbeat failed.",
@@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
         except asyncio.CancelledError:
             logger.debug("Shutting down MQLLMEngineClient check health loop.")
 
+        except psutil.NoSuchProcess:
+            self._set_errored(
+                RuntimeError(
+                    f"Engine process (pid {self._engine_process.pid}) died."))
+
         except Exception as e:
             self._set_errored(e)
 
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index f67acdf66075..0a7f430eca48 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -1,7 +1,5 @@
 import pickle
 import signal
-import threading
-import time
 from contextlib import contextmanager
 from typing import Iterator, List, Optional, Union
 
@@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 # yapf: enable
-from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
+from vllm.envs import VLLM_USE_V1
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -108,20 +106,6 @@ class MQLLMEngine:
         # Error state.
         self._errored_with: Optional[BaseException] = None
 
-        # Heartbeat thread
-        self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
-                                                 daemon=True)
-        self._heartbeat_stop_event = threading.Event()
-        # The heartbeat needs to be faster than what the client will wait for
-        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
-        self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
-
-        self._last_alive_time = time.time()
-        # The heartbeats can tolerate a long period of the engine chugging
-        # away at a generation request.
-        # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
-        self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
-
     @property
     def dead_error(self) -> BaseException:
         if self._errored_with is not None:
@@ -157,8 +141,6 @@ class MQLLMEngine:
             try:
                 logger.debug("Starting Startup Loop.")
                 self.run_startup_loop()
-                logger.debug("Starting heartbeat thread")
-                self.heartbeat_thread.start()
                 logger.debug("Starting Engine Loop.")
                 self.run_engine_loop()
             except Exception as e:
@@ -172,7 +154,6 @@ class MQLLMEngine:
     def cleanup(self):
         """Cleanup zeromq state on shutdown."""
         # Closes all sockets and destroys context.
-        self._heartbeat_stop_event.set()
         self.ctx.destroy(linger=0)
         del self.engine
 
@@ -211,11 +192,12 @@ class MQLLMEngine:
         """Core busy loop of the LLMEngine."""
 
         while True:
-            self._alive()
             if not self.engine.has_unfinished_requests():
                 # Poll until there is work to do.
                 while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
-                    self._alive()
+                    # When there's no work, check on engine health and send
+                    # health status back to client
+                    self._health_check()
                     self.engine.do_log_stats()
                     logger.debug("Waiting for new requests in engine loop.")
 
@@ -314,32 +296,19 @@ class MQLLMEngine:
         if self.log_requests:
             logger.info("Aborted request %s.", request.request_id)
 
-    def _heartbeat_loop(self):
-        while not self._heartbeat_stop_event.wait(
-                timeout=self.heartbeat_interval_seconds):
-            # Loops until the stop event is set
-            self._heartbeat()
-
-        logger.debug("Exiting MQLLMEngine heartbeat thread")
-
-    def _heartbeat(self):
+    def _health_check(self):
         # Send unhealthy if engine has already errored
         if self._errored_with is not None:
             self._send_unhealthy(self._errored_with)
-
-        # Check for life of the main loop
-        elif time.time() - self._last_alive_time > self.last_alive_threshold:
-            self._send_unhealthy(RuntimeError("Engine loop has died"))
-
-        else:
-            # Otherwise- check health of the engine
-            # self.engine.check_health() raises on unhealthy
-            try:
-                self.engine.check_health()
-                self._send_healthy()
-            except Exception as e:
-                self._set_errored(e)
-                self._send_unhealthy(e)
+            # Already dead: stop here so a passing check_health() below
+            # cannot follow the unhealthy report with a healthy one.
+            return
+        try:
+            self.engine.check_health()
+            self._send_healthy()
+        except Exception as e:
+            self._set_errored(e)
+            self._send_unhealthy(e)
 
     def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
         """Send List of RequestOutput to RPCClient."""
@@ -369,9 +338,6 @@ class MQLLMEngine:
         if self._errored_with is None:
             self._errored_with = e
 
-    def _alive(self):
-        self._last_alive_time = time.time()
-
     def start_profile(self) -> None:
         if type(self.engine.model_executor) is GPUExecutor:
             self.engine.model_executor.start_profile()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index afa370a1cb40..0e0ec311023e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
                                                UsageContext.OPENAI_API_SERVER,
                                                ipc_path))
         engine_process.start()
-        logger.info("Started engine process with PID %d", engine_process.pid)
+        engine_pid = engine_process.pid
+        assert engine_pid is not None, "Engine process failed to start"
+        logger.info("Started engine process with PID %d", engine_pid)
 
         # Build RPCClient, which conforms to EngineClient Protocol.
         # NOTE: Actually, this is not true yet. We still need to support
         # embedding models via RPC (see TODO above)
         engine_config = engine_args.create_engine_config()
-        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
+        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
+                                             engine_pid)
 
         try:
             while True: