Add health check, make async Engine more robust (#3015)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Antoni Baum 2024-03-04 14:01:40 -08:00 committed by GitHub
parent 22de45235c
commit ff578cae54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 138 additions and 65 deletions

View File

@ -25,12 +25,8 @@ class MockEngine:
return [RequestOutput( return [RequestOutput(
request_id=self.request_id)] if self.request_id else [] request_id=self.request_id)] if self.request_id else []
async def encode_request_async( async def encode_request_async(self, *args, **kwargs):
self, pass
*args,
**kwargs,
):
return [1]
def generate(self, request_id): def generate(self, request_id):
self.request_id = request_id self.request_id = request_id
@ -43,13 +39,16 @@ class MockEngine:
self.add_request_calls += 1 self.add_request_calls += 1
async def add_request_async(self, **kwargs): async def add_request_async(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1 self.add_request_calls += 1
return
def abort_request(self, request_id): def abort_request(self, request_id):
del request_id # Unused del request_id # Unused
self.abort_request_calls += 1 self.abort_request_calls += 1
def has_unfinished_requests(self):
    # Mock liveness signal: a request counts as "in flight" while
    # self.request_id is set (generate() records it — see above).
    return self.request_id is not None
class MockAsyncLLMEngine(AsyncLLMEngine): class MockAsyncLLMEngine(AsyncLLMEngine):
@ -72,20 +71,21 @@ async def test_new_requests_event():
await engine.add_request("2", "", None) await engine.add_request("2", "", None)
engine.engine.generate("2") engine.engine.generate("2")
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2 assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2 assert engine.engine.step_calls >= 2
await asyncio.sleep(0) await asyncio.sleep(0.001)
assert engine.engine.step_calls == 3 assert engine.engine.step_calls >= 3
engine.engine.stop_generating() engine.engine.stop_generating()
await asyncio.sleep(0) await asyncio.sleep(0.001)
assert engine.engine.step_calls == 4 old_step_calls = engine.engine.step_calls
await asyncio.sleep(0) await asyncio.sleep(0.001)
assert engine.engine.step_calls == 4 assert engine.engine.step_calls == old_step_calls
await engine.add_request("3", "", None) await engine.add_request("3", "", None)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3 assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5 assert engine.engine.step_calls == old_step_calls + 1
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3 assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5 assert engine.engine.step_calls == old_step_calls + 1

View File

@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
class DummyEvent: @pytest.mark.asyncio
async def test_request_tracker():
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracker():
tracker = RequestTracker() tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1") stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.flag assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.is_set()
assert len(new) == 1 assert len(new) == 1
assert new[0]["request_id"] == "1" assert new[0]["request_id"] == "1"
assert not finished assert not finished
@ -30,9 +19,10 @@ def test_request_tracker():
stream_2 = tracker.add_request("2") stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3") stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.flag assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.is_set()
assert len(new) == 2 assert len(new) == 2
assert new[0]["request_id"] == "2" assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3" assert new[1]["request_id"] == "3"
@ -43,7 +33,7 @@ def test_request_tracker():
# request_ids must be unique # request_ids must be unique
with pytest.raises(KeyError): with pytest.raises(KeyError):
tracker.add_request("1") tracker.add_request("1")
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.is_set()
tracker.abort_request("1") tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
@ -54,7 +44,8 @@ def test_request_tracker():
stream_4 = tracker.add_request("4") stream_4 = tracker.add_request("4")
tracker.abort_request("4") tracker.abort_request("4")
assert tracker.new_requests_event.flag assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1 assert len(finished) == 1
assert "4" in finished assert "4" in finished
@ -62,11 +53,12 @@ def test_request_tracker():
assert stream_4.finished assert stream_4.finished
stream_5 = tracker.add_request("5") stream_5 = tracker.add_request("5")
assert tracker.new_requests_event.flag assert tracker.new_requests_event.is_set()
tracker.process_request_output( tracker.process_request_output(
RequestOutput("2", "output", [], [], [], bool(finished))) RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests() new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.is_set()
assert len(finished) == 1 assert len(finished) == 1
assert "2" in finished assert "2" in finished
assert len(new) == 1 assert len(new) == 1

View File

@ -1,8 +1,9 @@
import asyncio import asyncio
import os
import time import time
from functools import partial from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator) Union, AsyncIterator, Callable)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
# Per-iteration timeout (seconds) for the background engine loop; applied
# via asyncio.wait_for around engine_step() so unrecoverable hangs (e.g.
# NCCL timeouts) kill the loop instead of stalling forever. Overridable
# through the VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable.
ENGINE_ITERATION_TIMEOUT_S = int(
    os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
class AsyncEngineDeadError(RuntimeError): class AsyncEngineDeadError(RuntimeError):
pass pass
def _raise_exception_on_finish(task: asyncio.Task, def _raise_exception_on_finish(
request_tracker: "RequestTracker") -> None: task: asyncio.Task, error_callback: Callable[[Exception],
None]) -> None:
msg = ("Task finished unexpectedly. This should never happen! " msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.") "Please open an issue on Github.")
exception = None
try: try:
try: task.result()
task.result() # NOTE: This will be thrown if task exits normally (which it should not)
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from exc
raise AsyncEngineDeadError(msg) raise AsyncEngineDeadError(msg)
except Exception as exc: except Exception as e:
request_tracker.propagate_exception(exc) exception = e
raise exc logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e
class AsyncStream: class AsyncStream:
@ -78,13 +82,13 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue() self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream, self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue() dict]] = asyncio.Queue()
self.new_requests_event = None self.new_requests_event = asyncio.Event()
def __contains__(self, item): def __contains__(self, item):
return item in self._request_streams return item in self._request_streams
def init_event(self): def __len__(self) -> int:
self.new_requests_event = asyncio.Event() return len(self._request_streams)
def propagate_exception(self, def propagate_exception(self,
exc: Exception, exc: Exception,
@ -93,9 +97,11 @@ class RequestTracker:
(all if request_id is None).""" (all if request_id is None)."""
if request_id is not None: if request_id is not None:
self._request_streams[request_id].put(exc) self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else: else:
for stream in self._request_streams.values(): for rid, stream in self._request_streams.items():
stream.put(exc) stream.put(exc)
self.abort_request(rid)
def process_request_output(self, def process_request_output(self,
request_output: RequestOutput, request_output: RequestOutput,
@ -172,12 +178,15 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream self._request_streams[stream.request_id] = stream
new_requests.append(new_request) new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests return new_requests, finished_requests
async def wait_for_new_requests(self): async def wait_for_new_requests(self):
await self.new_requests_event.wait() if not self.has_new_requests():
await self.new_requests_event.wait()
self.new_requests_event.clear()
def has_new_requests(self):
return not self._new_requests.empty()
class _AsyncLLMEngine(LLMEngine): class _AsyncLLMEngine(LLMEngine):
@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
all_outputs = await asyncio.gather(*coros) all_outputs = await asyncio.gather(*coros)
return all_outputs return all_outputs
async def check_health_async(self):
    """Raises an error if engine is unhealthy."""
    # Async facade over the synchronous Ray-actor liveness check
    # (defined on LLMEngine); presumably a no-op when workers are not
    # Ray actors — confirm against _check_if_any_actor_is_dead.
    self._check_if_any_actor_is_dead()
class AsyncLLMEngine: class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine. """An asynchronous wrapper for LLMEngine.
@ -335,27 +348,48 @@ class AsyncLLMEngine:
# collected # collected
self._background_loop_unshielded = None self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker() self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
return (self.background_loop is not None return (self.background_loop is not None
and not self.background_loop.done()) and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
    """True once the engine has errored or its background loop has exited."""
    if self.errored:
        return True
    loop_started = self.background_loop is not None
    return loop_started and self._background_loop_unshielded.done()
@property
def errored(self) -> bool:
    """Whether a fatal exception has been recorded via set_errored()."""
    return self._errored_with is not None
def set_errored(self, exc: Exception) -> None:
    """Record *exc* as the fatal engine error; makes `errored` True."""
    self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
    """Mark the engine errored and forward *exc* to all request streams.

    Order matters: errored state is set first so callers observing the
    propagated exception already see `errored == True`.
    """
    self.set_errored(exc)
    self._request_tracker.propagate_exception(exc)
def get_tokenizer(self): def get_tokenizer(self):
return self.engine.tokenizer.tokenizer return self.engine.tokenizer.tokenizer
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running: if self.is_running:
raise RuntimeError("Background loop is already running.") raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event() # Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop()) ).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback( self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, partial(_raise_exception_on_finish,
request_tracker=self._request_tracker)) error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded) self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, def _init_engine(self, *args,
@ -423,12 +457,23 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False has_requests_in_progress = False
while True: while True:
if not has_requests_in_progress: if not has_requests_in_progress:
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step() logger.debug("Got new requests!")
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
raise
await asyncio.sleep(0) await asyncio.sleep(0)
async def add_request( async def add_request(
@ -647,3 +692,19 @@ class AsyncLLMEngine:
await self.engine.do_log_stats.remote() await self.engine.do_log_stats.remote()
else: else:
self.engine.do_log_stats() self.engine.do_log_stats()
async def check_health(self):
    """Raises an error if engine is unhealthy.

    Raises:
        AsyncEngineDeadError: if the background loop has stopped.
        RuntimeError: if the remote (Ray) engine actor has died.
    """
    t = time.perf_counter()
    logger.debug("Starting health check...")
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    if self.engine_use_ray:
        try:
            await self.engine.check_health.remote()
        except ray.exceptions.RayActorError as e:
            raise RuntimeError("Engine is dead.") from e
    else:
        await self.engine.check_health_async()
    # Lazy %-formatting: the duration string is only built when the
    # DEBUG level is actually enabled (was an eager f-string before).
    logger.debug("Health check took %fs", time.perf_counter() - t)

View File

@ -1119,3 +1119,23 @@ class LLMEngine:
for worker in self.workers for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile()
def check_health(self) -> None:
    """Raises an error if engine is unhealthy."""
    # The only criterion checked here is Ray worker-actor liveness; for
    # non-Ray deployments the delegate returns immediately.
    self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
    """Raise RuntimeError if any Ray worker actor is in the DEAD state."""
    # Only meaningful when workers run as Ray actors and exist at all.
    if not self.parallel_config.worker_use_ray or not self.workers:
        return

    dead_actors = [
        worker for worker in self.workers
        if ray.state.actors(worker._ray_actor_id.hex())  # pylint: disable=protected-access
        ["State"] == "DEAD"
    ]
    if dead_actors:
        raise RuntimeError("At least one Worker is dead. "
                           f"Dead Workers: {dead_actors}. ")