mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:25:01 +08:00
[Bugfix][core] replace heartbeat with pid check (#9818)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
9ff4511e43
commit
3b3f1e7436
@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
MODEL = "google/gemma-1.1-2b-it"
|
||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
|
||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
|
||||
RAISED_ERROR = KeyError
|
||||
RAISED_VALUE = "foo"
|
||||
|
||||
@ -266,3 +266,28 @@ async def test_mp_cuda_init():
|
||||
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_process_death(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# kill the engine process
|
||||
engine.proc.kill()
|
||||
|
||||
# Generate call should fail
|
||||
with pytest.raises(MQEngineDeadError):
|
||||
async for _ in client.generate(prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
|
||||
# And the health check should show the engine is dead
|
||||
with pytest.raises(RuntimeError, match="Engine process .* died"):
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
||||
@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
|
||||
|
||||
async def make_client(self) -> MQLLMEngineClient:
|
||||
engine_config = self.engine_args.create_engine_config()
|
||||
client = MQLLMEngineClient(self.ipc_path, engine_config)
|
||||
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
|
||||
while True:
|
||||
try:
|
||||
await client.setup()
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
||||
Optional, Union, cast, overload)
|
||||
|
||||
import cloudpickle
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
|
||||
every N seconds, confirming the engine is healthy
|
||||
"""
|
||||
|
||||
def __init__(self, ipc_path: str, engine_config: EngineConfig):
|
||||
def __init__(self, ipc_path: str, engine_config: EngineConfig,
|
||||
engine_pid: int):
|
||||
self.context = zmq.asyncio.Context()
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Loop to check health of the LLMEngine periodically.
|
||||
# Started after the MQLLMEngine is ready.
|
||||
self.health_loop: Optional[asyncio.Task] = None
|
||||
self._engine_process = psutil.Process(engine_pid)
|
||||
|
||||
@staticmethod
|
||||
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||
@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
|
||||
socket.close(linger=0)
|
||||
|
||||
async def run_heartbeat_loop(self, timeout: int):
|
||||
"""Background loop that continually listens to the RPCServer for
|
||||
heartbeats.
|
||||
"""Background loop that continually checks to ensure the engine process
|
||||
is still alive.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
|
||||
# No heartbeat was received. Set error and exit the loop
|
||||
# Check if the engine process is running:
|
||||
if not self._engine_process.is_running() or (
|
||||
self._engine_process.status() == psutil.STATUS_ZOMBIE):
|
||||
# NB: is_running() returns True for zombies
|
||||
self._set_errored(
|
||||
TimeoutError("No heartbeat received "
|
||||
"from MQLLMEngine"))
|
||||
logger.debug("Shutting down MQLLMEngineClient check "
|
||||
"health loop due to timeout")
|
||||
RuntimeError(
|
||||
f"Engine process (pid {self._engine_process.pid}) "
|
||||
"died."))
|
||||
break
|
||||
|
||||
else:
|
||||
if await self.heartbeat_socket.poll(timeout=timeout):
|
||||
# Heartbeat received- check the message
|
||||
await self._check_success(
|
||||
error_message="Heartbeat failed.",
|
||||
@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
self._set_errored(
|
||||
RuntimeError(
|
||||
f"Engine process (pid {self._engine_process.pid}) died."))
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import pickle
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
@ -108,20 +106,6 @@ class MQLLMEngine:
|
||||
# Error state.
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Heartbeat thread
|
||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
|
||||
daemon=True)
|
||||
self._heartbeat_stop_event = threading.Event()
|
||||
# The heartbeat needs to be faster than what the client will wait for
|
||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
||||
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
|
||||
|
||||
self._last_alive_time = time.time()
|
||||
# The heartbeats can tolerate a long period of the engine chugging
|
||||
# away at a generation request.
|
||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
||||
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
if self._errored_with is not None:
|
||||
@ -157,8 +141,6 @@ class MQLLMEngine:
|
||||
try:
|
||||
logger.debug("Starting Startup Loop.")
|
||||
self.run_startup_loop()
|
||||
logger.debug("Starting heartbeat thread")
|
||||
self.heartbeat_thread.start()
|
||||
logger.debug("Starting Engine Loop.")
|
||||
self.run_engine_loop()
|
||||
except Exception as e:
|
||||
@ -172,7 +154,6 @@ class MQLLMEngine:
|
||||
def cleanup(self):
|
||||
"""Cleanup zeromq state on shutdown."""
|
||||
# Closes all sockets and destroys context.
|
||||
self._heartbeat_stop_event.set()
|
||||
self.ctx.destroy(linger=0)
|
||||
del self.engine
|
||||
|
||||
@ -211,11 +192,12 @@ class MQLLMEngine:
|
||||
"""Core busy loop of the LLMEngine."""
|
||||
|
||||
while True:
|
||||
self._alive()
|
||||
if not self.engine.has_unfinished_requests():
|
||||
# Poll until there is work to do.
|
||||
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
self._alive()
|
||||
# When there's no work, check on engine health and send
|
||||
# health status back to client
|
||||
self._health_check()
|
||||
self.engine.do_log_stats()
|
||||
logger.debug("Waiting for new requests in engine loop.")
|
||||
|
||||
@ -314,32 +296,16 @@ class MQLLMEngine:
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request.request_id)
|
||||
|
||||
def _heartbeat_loop(self):
|
||||
while not self._heartbeat_stop_event.wait(
|
||||
timeout=self.heartbeat_interval_seconds):
|
||||
# Loops until the stop event is set
|
||||
self._heartbeat()
|
||||
|
||||
logger.debug("Exiting MQLLMEngine heartbeat thread")
|
||||
|
||||
def _heartbeat(self):
|
||||
def _health_check(self):
|
||||
# Send unhealthy if engine has already errored
|
||||
if self._errored_with is not None:
|
||||
self._send_unhealthy(self._errored_with)
|
||||
|
||||
# Check for life of the main loop
|
||||
elif time.time() - self._last_alive_time > self.last_alive_threshold:
|
||||
self._send_unhealthy(RuntimeError("Engine loop has died"))
|
||||
|
||||
else:
|
||||
# Otherwise- check health of the engine
|
||||
# self.engine.check_health() raises on unhealthy
|
||||
try:
|
||||
self.engine.check_health()
|
||||
self._send_healthy()
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
self._send_unhealthy(e)
|
||||
try:
|
||||
self.engine.check_health()
|
||||
self._send_healthy()
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
self._send_unhealthy(e)
|
||||
|
||||
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
|
||||
"""Send List of RequestOutput to RPCClient."""
|
||||
@ -369,9 +335,6 @@ class MQLLMEngine:
|
||||
if self._errored_with is None:
|
||||
self._errored_with = e
|
||||
|
||||
def _alive(self):
|
||||
self._last_alive_time = time.time()
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if type(self.engine.model_executor) is GPUExecutor:
|
||||
self.engine.model_executor.start_profile()
|
||||
|
||||
@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
|
||||
UsageContext.OPENAI_API_SERVER,
|
||||
ipc_path))
|
||||
engine_process.start()
|
||||
logger.info("Started engine process with PID %d", engine_process.pid)
|
||||
engine_pid = engine_process.pid
|
||||
assert engine_pid is not None, "Engine process failed to start"
|
||||
logger.info("Started engine process with PID %d", engine_pid)
|
||||
|
||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||
# NOTE: Actually, this is not true yet. We still need to support
|
||||
# embedding models via RPC (see TODO above)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
|
||||
engine_pid)
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user