Add health check, make async Engine more robust (#3015)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

parent: 22de45235c
commit: ff578cae54
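
This diff adds a `check_health()` API to both `LLMEngine` and `AsyncLLMEngine`, tracks background-loop failures via `errored`/`set_errored`, and bounds each engine iteration with a timeout. As a minimal usage sketch (not part of this diff; the handler name and status-code mapping are assumed for illustration), a server could surface the new API like this:

    # Hypothetical caller of the new health-check API. Only check_health()
    # and AsyncEngineDeadError come from this commit; the rest is illustrative.
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def health_handler(engine: AsyncLLMEngine) -> int:
        try:
            # Raises (e.g. AsyncEngineDeadError) if the background loop has
            # stopped or a Ray worker actor is dead.
            await engine.check_health()
            return 200
        except Exception:
            return 500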
tests/async_engine/test_async_llm_engine.py
@@ -25,12 +25,8 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []
 
-    async def encode_request_async(
-        self,
-        *args,
-        **kwargs,
-    ):
-        return [1]
+    async def encode_request_async(self, *args, **kwargs):
+        pass
 
     def generate(self, request_id):
         self.request_id = request_id
@@ -43,13 +39,16 @@ class MockEngine:
         self.add_request_calls += 1
 
+    async def add_request_async(self, **kwargs):
+        del kwargs  # Unused
+        self.add_request_calls += 1
+        return
+
     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1
 
     def has_unfinished_requests(self):
         return self.request_id is not None
 
 
 class MockAsyncLLMEngine(AsyncLLMEngine):
 
@@ -72,20 +71,21 @@ async def test_new_requests_event():
     await engine.add_request("2", "", None)
     engine.engine.generate("2")
     await asyncio.sleep(0)
+    await asyncio.sleep(0)
     assert engine.engine.add_request_calls == 2
-    assert engine.engine.step_calls == 2
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 3
+    assert engine.engine.step_calls >= 2
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls >= 3
     engine.engine.stop_generating()
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0.001)
+    old_step_calls = engine.engine.step_calls
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls == old_step_calls
 
     await engine.add_request("3", "", None)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1
tests/async_engine/test_request_tracker.py
@@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
 from vllm.outputs import RequestOutput
 
 
-class DummyEvent:
-
-    def __init__(self):
-        self.flag = False
-
-    def set(self):
-        self.flag = True
-
-    def clear(self):
-        self.flag = False
-
-
-def test_request_tracker():
+@pytest.mark.asyncio
+async def test_request_tracker():
     tracker = RequestTracker()
-    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -30,9 +19,10 @@ def test_request_tracker():
 
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -43,7 +33,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
 
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +44,8 @@ def test_request_tracker():
 
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -62,11 +53,12 @@ def test_request_tracker():
     assert stream_4.finished
 
     stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], [], bool(finished)))
+        RequestOutput("2", "output", [], [], [], finished=True))
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
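
Note: the tests above switch from the hand-rolled DummyEvent to the tracker's real asyncio.Event, so they must run inside an event loop via pytest.mark.asyncio (the pytest-asyncio plugin). A self-contained sketch of the same set/wait/clear pattern, with illustrative names:

    import asyncio

    import pytest

    @pytest.mark.asyncio
    async def test_event_wakeup():
        event = asyncio.Event()
        event.set()                # producer signals that work arrived
        assert event.is_set()
        await event.wait()         # consumer returns immediately when set
        event.clear()              # consume the signal for the next round
        assert not event.is_set()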
vllm/engine/async_llm_engine.py
@@ -1,8 +1,9 @@
 import asyncio
+import os
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union, AsyncIterator, Callable)
 
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
+ENGINE_ITERATION_TIMEOUT_S = int(
+    os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
 
 
 class AsyncEngineDeadError(RuntimeError):
     pass
 
 
-def _raise_exception_on_finish(task: asyncio.Task,
-                               request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(
+        task: asyncio.Task, error_callback: Callable[[Exception],
+                                                     None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github.")
-    try:
-        try:
-            task.result()
-        except asyncio.CancelledError:
-            return
-        except Exception as exc:
-            raise AsyncEngineDeadError(
-                msg + " See stack trace above for the actual cause.") from exc
-        # NOTE: This will be thrown if task exits normally (which it should not)
-        raise AsyncEngineDeadError(msg)
-    except Exception as exc:
-        request_tracker.propagate_exception(exc)
-        raise exc
+
+    exception = None
+    try:
+        task.result()
+        # NOTE: This will be thrown if task exits normally (which it should not)
+        raise AsyncEngineDeadError(msg)
+    except Exception as e:
+        exception = e
+        logger.error("Engine background task failed", exc_info=e)
+        error_callback(exception)
+        raise AsyncEngineDeadError(
+            msg + " See stack trace above for the actual cause.") from e
 
 
 class AsyncStream:
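
Note: _raise_exception_on_finish now receives a generic error_callback instead of a RequestTracker, so the done-callback only reports failures and the engine decides how to propagate them. The underlying asyncio pattern, as a standalone sketch (names are illustrative):

    import asyncio
    from functools import partial

    def _on_done(task: asyncio.Task, error_callback) -> None:
        # Done callbacks receive the finished task; forward its exception.
        if not task.cancelled() and task.exception() is not None:
            error_callback(task.exception())

    async def main() -> None:
        task = asyncio.get_event_loop().create_task(asyncio.sleep(0.1))
        task.add_done_callback(partial(_on_done, error_callback=print))
        await task

    asyncio.run(main())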
@@ -78,13 +82,13 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
-        self.new_requests_event = None
+        self.new_requests_event = asyncio.Event()
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def init_event(self):
-        self.new_requests_event = asyncio.Event()
+    def __len__(self) -> int:
+        return len(self._request_streams)
 
     def propagate_exception(self,
                             exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
         (all if request_id is None)."""
         if request_id is not None:
             self._request_streams[request_id].put(exc)
+            self.abort_request(request_id)
         else:
-            for stream in self._request_streams.values():
+            for rid, stream in self._request_streams.items():
                 stream.put(exc)
+                self.abort_request(rid)
 
     def process_request_output(self,
                                request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
-        self.new_requests_event.clear()
-
         return new_requests, finished_requests
 
+    async def wait_for_new_requests(self):
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
+
     def has_new_requests(self):
         return not self._new_requests.empty()
 
 
 class _AsyncLLMEngine(LLMEngine):
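
Note: wait_for_new_requests only blocks when nothing is queued and clears the event after waking; without the has_new_requests() guard, a request queued while the event was already cleared could leave the caller waiting on work that is already pending. The same pattern in isolation (the class name is illustrative):

    import asyncio

    class Inbox:

        def __init__(self) -> None:
            self._items: asyncio.Queue = asyncio.Queue()
            self._event = asyncio.Event()

        def put(self, item) -> None:
            self._items.put_nowait(item)
            self._event.set()

        async def wait_for_items(self) -> None:
            # Skip the wait if work is already queued; the event may have
            # been cleared by an earlier consumer pass.
            if self._items.empty():
                await self._event.wait()
            self._event.clear()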
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
         all_outputs = await asyncio.gather(*coros)
         return all_outputs
 
+    async def check_health_async(self):
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
 
 class AsyncLLMEngine:
     """An asynchronous wrapper for LLMEngine.
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
         # collected
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracker: Optional[RequestTracker] = None
+        self._errored_with: Optional[BaseException] = None
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
-                and not self.background_loop.done())
+                and not self._background_loop_unshielded.done())
+
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+        return self._errored_with is not None
+
+    def set_errored(self, exc: Exception) -> None:
+        self._errored_with = exc
+
+    def _error_callback(self, exc: Exception) -> None:
+        self.set_errored(exc)
+        self._request_tracker.propagate_exception(exc)
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
+        if self.errored:
+            raise AsyncEngineDeadError(
+                "Background loop has errored already.") from self._errored_with
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self._request_tracker.init_event()
+        # Initialize the RequestTracker here so it uses the right event loop.
+        self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self._request_tracker))
+                    error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
         self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        # Initialize the RequestTracker here so it uses the right event loop.
         has_requests_in_progress = False
         while True:
             if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
                 await self._request_tracker.wait_for_new_requests()
-            has_requests_in_progress = await self.engine_step()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
             await asyncio.sleep(0)
 
     async def add_request(
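
Note: wrapping engine_step() in asyncio.wait_for turns an indefinite hang (e.g. a stuck NCCL collective) into an asyncio.TimeoutError, which marks the engine as errored instead of silently stalling. The core shape of that guard, reduced to a sketch (names are illustrative):

    import asyncio

    async def step() -> bool:
        await asyncio.sleep(0.01)  # stand-in for one engine iteration
        return True

    async def guarded_step(timeout_s: float = 60.0) -> bool:
        try:
            return await asyncio.wait_for(step(), timeout_s)
        except asyncio.TimeoutError:
            # The diff logs, records the error via set_errored(), and
            # re-raises so the done-callback can tear the loop down.
            raise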
@@ -647,3 +692,19 @@ class AsyncLLMEngine:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
+
+    async def check_health(self):
+        """Raises an error if engine is unhealthy."""
+        t = time.perf_counter()
+        logger.debug("Starting health check...")
+        if self.is_stopped:
+            raise AsyncEngineDeadError("Background loop is stopped.")
+
+        if self.engine_use_ray:
+            try:
+                await self.engine.check_health.remote()
+            except ray.exceptions.RayActorError as e:
+                raise RuntimeError("Engine is dead.") from e
+        else:
+            await self.engine.check_health_async()
+        logger.debug(f"Health check took {time.perf_counter()-t}s")
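
Note: ENGINE_ITERATION_TIMEOUT_S is evaluated once at module import from the VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable (defaulting to 60 seconds), so an override must be set before the module is imported:

    import os

    # Must run before importing vllm.engine.async_llm_engine, since the
    # constant is read at module load time.
    os.environ["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "120"

    from vllm.engine.async_llm_engine import AsyncLLMEngine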
vllm/engine/llm_engine.py
@@ -1119,3 +1119,23 @@ class LLMEngine:
                 for worker in self.workers
             ])
         return forward_dag.experimental_compile()
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.parallel_config.worker_use_ray:
+            return
+
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead Workers: {dead_actors}. ")