mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Bugfix][core] replace heartbeat with pid check (#9818)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
9ff4511e43
commit
3b3f1e7436
@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
|
|||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
MODEL = "google/gemma-1.1-2b-it"
|
MODEL = "google/gemma-1.1-2b-it"
|
||||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
|
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
|
||||||
RAISED_ERROR = KeyError
|
RAISED_ERROR = KeyError
|
||||||
RAISED_VALUE = "foo"
|
RAISED_VALUE = "foo"
|
||||||
|
|
||||||
@ -266,3 +266,28 @@ async def test_mp_cuda_init():
|
|||||||
|
|
||||||
async with build_async_engine_client(args):
|
async with build_async_engine_client(args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_engine_process_death(tmp_socket):
|
||||||
|
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||||
|
ipc_path=tmp_socket) as engine:
|
||||||
|
|
||||||
|
client = await engine.make_client()
|
||||||
|
assert client.is_running
|
||||||
|
|
||||||
|
# kill the engine process
|
||||||
|
engine.proc.kill()
|
||||||
|
|
||||||
|
# Generate call should fail
|
||||||
|
with pytest.raises(MQEngineDeadError):
|
||||||
|
async for _ in client.generate(prompt="Hello my name is",
|
||||||
|
sampling_params=SamplingParams(),
|
||||||
|
request_id=uuid.uuid4()):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# And the health check should show the engine is dead
|
||||||
|
with pytest.raises(RuntimeError, match="Engine process .* died"):
|
||||||
|
await client.check_health()
|
||||||
|
|
||||||
|
client.close()
|
||||||
|
|||||||
@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
|
|||||||
|
|
||||||
async def make_client(self) -> MQLLMEngineClient:
|
async def make_client(self) -> MQLLMEngineClient:
|
||||||
engine_config = self.engine_args.create_engine_config()
|
engine_config = self.engine_args.create_engine_config()
|
||||||
client = MQLLMEngineClient(self.ipc_path, engine_config)
|
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await client.setup()
|
await client.setup()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
|||||||
Optional, Union, cast, overload)
|
Optional, Union, cast, overload)
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
|
import psutil
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from zmq import Frame # type: ignore[attr-defined]
|
from zmq import Frame # type: ignore[attr-defined]
|
||||||
@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
every N seconds, confirming the engine is healthy
|
every N seconds, confirming the engine is healthy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ipc_path: str, engine_config: EngineConfig):
|
def __init__(self, ipc_path: str, engine_config: EngineConfig,
|
||||||
|
engine_pid: int):
|
||||||
self.context = zmq.asyncio.Context()
|
self.context = zmq.asyncio.Context()
|
||||||
self._errored_with: Optional[BaseException] = None
|
self._errored_with: Optional[BaseException] = None
|
||||||
|
|
||||||
@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
# Loop to check health of the LLMEngine periodically.
|
# Loop to check health of the LLMEngine periodically.
|
||||||
# Started after the MQLLMEngine is ready.
|
# Started after the MQLLMEngine is ready.
|
||||||
self.health_loop: Optional[asyncio.Task] = None
|
self.health_loop: Optional[asyncio.Task] = None
|
||||||
|
self._engine_process = psutil.Process(engine_pid)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||||
@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
socket.close(linger=0)
|
socket.close(linger=0)
|
||||||
|
|
||||||
async def run_heartbeat_loop(self, timeout: int):
|
async def run_heartbeat_loop(self, timeout: int):
|
||||||
"""Background loop that continually listens to the RPCServer for
|
"""Background loop that continually checks to ensure the engine process
|
||||||
heartbeats.
|
is still alive.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
|
# Check if the engine process is running:
|
||||||
# No heartbeat was received. Set error and exit the loop
|
if not self._engine_process.is_running() or (
|
||||||
|
self._engine_process.status() == psutil.STATUS_ZOMBIE):
|
||||||
|
# NB: is_running() returns True for zombies
|
||||||
self._set_errored(
|
self._set_errored(
|
||||||
TimeoutError("No heartbeat received "
|
RuntimeError(
|
||||||
"from MQLLMEngine"))
|
f"Engine process (pid {self._engine_process.pid}) "
|
||||||
logger.debug("Shutting down MQLLMEngineClient check "
|
"died."))
|
||||||
"health loop due to timeout")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
else:
|
if await self.heartbeat_socket.poll(timeout=timeout):
|
||||||
# Heartbeat received- check the message
|
# Heartbeat received- check the message
|
||||||
await self._check_success(
|
await self._check_success(
|
||||||
error_message="Heartbeat failed.",
|
error_message="Heartbeat failed.",
|
||||||
@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||||
|
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
self._set_errored(
|
||||||
|
RuntimeError(
|
||||||
|
f"Engine process (pid {self._engine_process.pid}) died."))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._set_errored(e)
|
self._set_errored(e)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import pickle
|
import pickle
|
||||||
import signal
|
import signal
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Iterator, List, Optional, Union
|
from typing import Iterator, List, Optional, Union
|
||||||
|
|
||||||
@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
|||||||
RPCStartupRequest, RPCStartupResponse,
|
RPCStartupRequest, RPCStartupResponse,
|
||||||
RPCUProfileRequest)
|
RPCUProfileRequest)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.executor.gpu_executor import GPUExecutor
|
from vllm.executor.gpu_executor import GPUExecutor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
@ -108,20 +106,6 @@ class MQLLMEngine:
|
|||||||
# Error state.
|
# Error state.
|
||||||
self._errored_with: Optional[BaseException] = None
|
self._errored_with: Optional[BaseException] = None
|
||||||
|
|
||||||
# Heartbeat thread
|
|
||||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
|
|
||||||
daemon=True)
|
|
||||||
self._heartbeat_stop_event = threading.Event()
|
|
||||||
# The heartbeat needs to be faster than what the client will wait for
|
|
||||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
|
||||||
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
|
|
||||||
|
|
||||||
self._last_alive_time = time.time()
|
|
||||||
# The heartbeats can tolerate a long period of the engine chugging
|
|
||||||
# away at a generation request.
|
|
||||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
|
||||||
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dead_error(self) -> BaseException:
|
def dead_error(self) -> BaseException:
|
||||||
if self._errored_with is not None:
|
if self._errored_with is not None:
|
||||||
@ -157,8 +141,6 @@ class MQLLMEngine:
|
|||||||
try:
|
try:
|
||||||
logger.debug("Starting Startup Loop.")
|
logger.debug("Starting Startup Loop.")
|
||||||
self.run_startup_loop()
|
self.run_startup_loop()
|
||||||
logger.debug("Starting heartbeat thread")
|
|
||||||
self.heartbeat_thread.start()
|
|
||||||
logger.debug("Starting Engine Loop.")
|
logger.debug("Starting Engine Loop.")
|
||||||
self.run_engine_loop()
|
self.run_engine_loop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -172,7 +154,6 @@ class MQLLMEngine:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Cleanup zeromq state on shutdown."""
|
"""Cleanup zeromq state on shutdown."""
|
||||||
# Closes all sockets and destroys context.
|
# Closes all sockets and destroys context.
|
||||||
self._heartbeat_stop_event.set()
|
|
||||||
self.ctx.destroy(linger=0)
|
self.ctx.destroy(linger=0)
|
||||||
del self.engine
|
del self.engine
|
||||||
|
|
||||||
@ -211,11 +192,12 @@ class MQLLMEngine:
|
|||||||
"""Core busy loop of the LLMEngine."""
|
"""Core busy loop of the LLMEngine."""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
self._alive()
|
|
||||||
if not self.engine.has_unfinished_requests():
|
if not self.engine.has_unfinished_requests():
|
||||||
# Poll until there is work to do.
|
# Poll until there is work to do.
|
||||||
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||||
self._alive()
|
# When there's no work, check on engine health and send
|
||||||
|
# health status back to client
|
||||||
|
self._health_check()
|
||||||
self.engine.do_log_stats()
|
self.engine.do_log_stats()
|
||||||
logger.debug("Waiting for new requests in engine loop.")
|
logger.debug("Waiting for new requests in engine loop.")
|
||||||
|
|
||||||
@ -314,32 +296,16 @@ class MQLLMEngine:
|
|||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Aborted request %s.", request.request_id)
|
logger.info("Aborted request %s.", request.request_id)
|
||||||
|
|
||||||
def _heartbeat_loop(self):
|
def _health_check(self):
|
||||||
while not self._heartbeat_stop_event.wait(
|
|
||||||
timeout=self.heartbeat_interval_seconds):
|
|
||||||
# Loops until the stop event is set
|
|
||||||
self._heartbeat()
|
|
||||||
|
|
||||||
logger.debug("Exiting MQLLMEngine heartbeat thread")
|
|
||||||
|
|
||||||
def _heartbeat(self):
|
|
||||||
# Send unhealthy if engine has already errored
|
# Send unhealthy if engine has already errored
|
||||||
if self._errored_with is not None:
|
if self._errored_with is not None:
|
||||||
self._send_unhealthy(self._errored_with)
|
self._send_unhealthy(self._errored_with)
|
||||||
|
try:
|
||||||
# Check for life of the main loop
|
self.engine.check_health()
|
||||||
elif time.time() - self._last_alive_time > self.last_alive_threshold:
|
self._send_healthy()
|
||||||
self._send_unhealthy(RuntimeError("Engine loop has died"))
|
except Exception as e:
|
||||||
|
self._set_errored(e)
|
||||||
else:
|
self._send_unhealthy(e)
|
||||||
# Otherwise- check health of the engine
|
|
||||||
# self.engine.check_health() raises on unhealthy
|
|
||||||
try:
|
|
||||||
self.engine.check_health()
|
|
||||||
self._send_healthy()
|
|
||||||
except Exception as e:
|
|
||||||
self._set_errored(e)
|
|
||||||
self._send_unhealthy(e)
|
|
||||||
|
|
||||||
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
|
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
|
||||||
"""Send List of RequestOutput to RPCClient."""
|
"""Send List of RequestOutput to RPCClient."""
|
||||||
@ -369,9 +335,6 @@ class MQLLMEngine:
|
|||||||
if self._errored_with is None:
|
if self._errored_with is None:
|
||||||
self._errored_with = e
|
self._errored_with = e
|
||||||
|
|
||||||
def _alive(self):
|
|
||||||
self._last_alive_time = time.time()
|
|
||||||
|
|
||||||
def start_profile(self) -> None:
|
def start_profile(self) -> None:
|
||||||
if type(self.engine.model_executor) is GPUExecutor:
|
if type(self.engine.model_executor) is GPUExecutor:
|
||||||
self.engine.model_executor.start_profile()
|
self.engine.model_executor.start_profile()
|
||||||
|
|||||||
@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
UsageContext.OPENAI_API_SERVER,
|
UsageContext.OPENAI_API_SERVER,
|
||||||
ipc_path))
|
ipc_path))
|
||||||
engine_process.start()
|
engine_process.start()
|
||||||
logger.info("Started engine process with PID %d", engine_process.pid)
|
engine_pid = engine_process.pid
|
||||||
|
assert engine_pid is not None, "Engine process failed to start"
|
||||||
|
logger.info("Started engine process with PID %d", engine_pid)
|
||||||
|
|
||||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||||
# NOTE: Actually, this is not true yet. We still need to support
|
# NOTE: Actually, this is not true yet. We still need to support
|
||||||
# embedding models via RPC (see TODO above)
|
# embedding models via RPC (see TODO above)
|
||||||
engine_config = engine_args.create_engine_config()
|
engine_config = engine_args.create_engine_config()
|
||||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
|
||||||
|
engine_pid)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user