Add health check, make async Engine more robust (#3015)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Antoni Baum 2024-03-04 14:01:40 -08:00 committed by GitHub
parent 22de45235c
commit ff578cae54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 138 additions and 65 deletions

View File

@ -25,12 +25,8 @@ class MockEngine:
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []
async def encode_request_async(
self,
*args,
**kwargs,
):
return [1]
async def encode_request_async(self, *args, **kwargs):
pass
def generate(self, request_id):
self.request_id = request_id
@ -43,13 +39,16 @@ class MockEngine:
self.add_request_calls += 1
async def add_request_async(self, **kwargs):
    """Record one asynchronous add-request call on the mock engine.

    Every keyword argument is accepted and ignored; the only effect is
    that ``add_request_calls`` grows by one so tests can count submissions.
    """
    # The mock never inspects the request payload, it only counts the call.
    self.add_request_calls = self.add_request_calls + 1
    return None
def abort_request(self, request_id):
    """Count an abort call; the mock ignores which request is aborted."""
    self.abort_request_calls = self.abort_request_calls + 1
def has_unfinished_requests(self):
    """Return True while the mock has a ``request_id`` set (not None)."""
    no_active_request = self.request_id is None
    return not no_active_request
class MockAsyncLLMEngine(AsyncLLMEngine):
@ -72,20 +71,21 @@ async def test_new_requests_event():
await engine.add_request("2", "", None)
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
assert engine.engine.step_calls >= 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0.001)
old_step_calls = engine.engine.step_calls
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls
await engine.add_request("3", "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
assert engine.engine.step_calls == old_step_calls + 1
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
assert engine.engine.step_calls == old_step_calls + 1

View File

@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput
class DummyEvent:
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracker():
@pytest.mark.asyncio
async def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
@ -30,9 +19,10 @@ def test_request_tracker():
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
@ -43,7 +33,7 @@ def test_request_tracker():
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
@ -54,7 +44,8 @@ def test_request_tracker():
stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
@ -62,11 +53,12 @@ def test_request_tracker():
assert stream_4.finished
stream_5 = tracker.add_request("5")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], bool(finished)))
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1

View File

@ -1,8 +1,9 @@
import asyncio
import os
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
Union, AsyncIterator, Callable)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
# Upper bound, in whole seconds, passed to ``asyncio.wait_for`` around a
# single engine step. Read once at import time from the
# VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable; defaults to 60.
ENGINE_ITERATION_TIMEOUT_S = int(
    os.getenv("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
class AsyncEngineDeadError(RuntimeError):
    """Raised when the async engine's background loop has failed or stopped."""
def _raise_exception_on_finish(task: asyncio.Task,
request_tracker: "RequestTracker") -> None:
def _raise_exception_on_finish(
task: asyncio.Task, error_callback: Callable[[Exception],
None]) -> None:
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.")
try:
exception = None
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from exc
# NOTE: This will be thrown if task exits normally (which it should not)
raise AsyncEngineDeadError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
except Exception as e:
exception = e
logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e
class AsyncStream:
@ -78,13 +82,13 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
self.new_requests_event = asyncio.Event()
def __contains__(self, item):
"""Return True if ``item`` is a key of the tracked request streams."""
return item in self._request_streams
def init_event(self):
self.new_requests_event = asyncio.Event()
def __len__(self) -> int:
"""Return the number of request streams currently tracked."""
return len(self._request_streams)
def propagate_exception(self,
exc: Exception,
@ -93,9 +97,11 @@ class RequestTracker:
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
for stream in self._request_streams.values():
for rid, stream in self._request_streams.items():
stream.put(exc)
self.abort_request(rid)
def process_request_output(self,
request_output: RequestOutput,
@ -172,12 +178,15 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
    """Block until a new request is pending, then reset the event.

    When requests are already queued the wait is skipped entirely. In both
    cases ``new_requests_event`` is cleared before returning, so a later
    call blocks again until the event is set once more.
    """
    nothing_queued = not self.has_new_requests()
    if nothing_queued:
        # Nothing to pick up yet: sleep until the event is set.
        await self.new_requests_event.wait()
    # Always reset, even when the wait was skipped, so the flag never
    # stays set after the pending requests have been noticed.
    self.new_requests_event.clear()
def has_new_requests(self):
    """Return True when the new-request queue holds at least one entry."""
    queue_is_empty = self._new_requests.empty()
    return not queue_is_empty
class _AsyncLLMEngine(LLMEngine):
@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
all_outputs = await asyncio.gather(*coros)
return all_outputs
async def check_health_async(self) -> None:
"""Raises an error if engine is unhealthy.

Performs the check by calling ``_check_if_any_actor_is_dead``. Nothing is
awaited inside; the coroutine form lets async callers ``await`` it.
"""
self._check_if_any_actor_is_dead()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
@ -335,27 +348,48 @@ class AsyncLLMEngine:
# collected
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
    """Whether the engine can no longer make progress.

    True when an error has been recorded, or when a background loop was
    started and its underlying (unshielded) task has finished.
    """
    if self.errored:
        return True
    loop_was_started = self.background_loop is not None
    return loop_was_started and self._background_loop_unshielded.done()
@property
def errored(self) -> bool:
    """Whether an error has been recorded in ``_errored_with``."""
    nothing_recorded = self._errored_with is None
    return not nothing_recorded
def set_errored(self, exc: Exception) -> None:
"""Record ``exc`` as the error the engine failed with.

The value is stored in ``_errored_with``, which the ``errored`` property
inspects.
"""
self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
"""Mark the engine as errored and forward ``exc`` to the request tracker.

Passed as ``error_callback`` to the background task's done-callback in
``start_background_loop``.
"""
# Record the failure first so ``errored`` is already True when the
# exception is propagated to the tracked requests.
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)
def get_tokenizer(self):
"""Return ``self.engine.tokenizer.tokenizer``.

NOTE(review): presumably the outer ``tokenizer`` object is a wrapper
around the actual tokenizer -- confirm against the engine class.
"""
return self.engine.tokenizer.tokenizer
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running:
raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event()
# Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self._request_tracker))
error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args,
@ -423,12 +457,23 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False
while True:
if not has_requests_in_progress:
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step()
logger.debug("Got new requests!")
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
raise
await asyncio.sleep(0)
async def add_request(
@ -647,3 +692,19 @@ class AsyncLLMEngine:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()
async def check_health(self):
    """Raises an error if engine is unhealthy.

    The check fails fast when the background loop is stopped; otherwise it
    asks the engine itself (remotely when the engine runs as a Ray actor,
    directly otherwise). The elapsed time is logged at debug level only
    when the check succeeds.

    Raises:
        AsyncEngineDeadError: If ``is_stopped`` is true.
        RuntimeError: If the remote engine call fails with
            ``ray.exceptions.RayActorError``.
    """
    t = time.perf_counter()
    logger.debug("Starting health check...")

    # No point querying the engine if nothing is driving it any more.
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    if not self.engine_use_ray:
        await self.engine.check_health_async()
    else:
        try:
            await self.engine.check_health.remote()
        except ray.exceptions.RayActorError as e:
            raise RuntimeError("Engine is dead.") from e

    logger.debug(f"Health check took {time.perf_counter()-t}s")

View File

@ -1119,3 +1119,23 @@ class LLMEngine:
for worker in self.workers
])
return forward_dag.experimental_compile()
def check_health(self) -> None:
"""Raises an error if engine is unhealthy.

Delegates to ``_check_if_any_actor_is_dead``, which raises ``RuntimeError``
when a worker actor is reported dead.
"""
self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
if not self.parallel_config.worker_use_ray:
return
if not self.workers:
return
dead_actors = []
for actor in self.workers:
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
if actor_state["State"] == "DEAD":
dead_actors.append(actor)
if dead_actors:
raise RuntimeError("At least one Worker is dead. "
f"Dead Workers: {dead_actors}. ")