Add health check, make async Engine more robust (#3015)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

parent 22de45235c
commit ff578cae54
@@ -25,12 +25,8 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []

-    async def encode_request_async(
-            self,
-            *args,
-            **kwargs,
-    ):
-        return [1]
+    async def encode_request_async(self, *args, **kwargs):
+        pass

     def generate(self, request_id):
         self.request_id = request_id
@@ -43,13 +39,16 @@ class MockEngine:
         self.add_request_calls += 1

     async def add_request_async(self, **kwargs):
-        del kwargs  # Unused
         self.add_request_calls += 1
+        return

     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1

+    def has_unfinished_requests(self):
+        return self.request_id is not None
+

 class MockAsyncLLMEngine(AsyncLLMEngine):

@@ -72,20 +71,21 @@ async def test_new_requests_event():
     await engine.add_request("2", "", None)
     engine.engine.generate("2")
     await asyncio.sleep(0)
+    await asyncio.sleep(0)
     assert engine.engine.add_request_calls == 2
-    assert engine.engine.step_calls == 2
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 3
+    assert engine.engine.step_calls >= 2
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls >= 3
     engine.engine.stop_generating()
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
-    await asyncio.sleep(0)
-    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0.001)
+    old_step_calls = engine.engine.step_calls
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls == old_step_calls

     await engine.add_request("3", "", None)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
-    assert engine.engine.step_calls == 5
+    assert engine.engine.step_calls == old_step_calls + 1
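
Note on the relaxed assertions above: once the background loop awaits with real timeouts, the number of step() calls per yield is no longer deterministic, so the test pins only monotone progress (>=) and deltas against old_step_calls. A minimal, self-contained sketch of the scheduling behavior involved; background_task and counter are illustrative names, not part of the diff:

    import asyncio

    async def background_task(counter: dict) -> None:
        # Runs one "step" each time the event loop schedules it.
        while True:
            counter["steps"] += 1
            await asyncio.sleep(0)

    async def main() -> None:
        counter = {"steps": 0}
        task = asyncio.create_task(background_task(counter))
        # Each sleep yields control; how many steps run in between depends
        # on loop scheduling, so tests should assert >= rather than ==.
        await asyncio.sleep(0)
        first = counter["steps"]
        await asyncio.sleep(0.001)
        assert counter["steps"] >= first
        task.cancel()

    asyncio.run(main())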
@@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
 from vllm.outputs import RequestOutput


-class DummyEvent:
-
-    def __init__(self):
-        self.flag = False
-
-    def set(self):
-        self.flag = True
-
-    def clear(self):
-        self.flag = False
-
-
-def test_request_tracker():
+@pytest.mark.asyncio
+async def test_request_tracker():
     tracker = RequestTracker()
-    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -30,9 +19,10 @@ def test_request_tracker():

     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -43,7 +33,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()

     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +44,8 @@ def test_request_tracker():

     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -62,11 +53,12 @@ def test_request_tracker():
     assert stream_4.finished

     stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event.flag
+    assert tracker.new_requests_event.is_set()
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], [], bool(finished)))
+        RequestOutput("2", "output", [], [], [], finished=True))
+    await tracker.wait_for_new_requests()
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event.flag
+    assert not tracker.new_requests_event.is_set()
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1
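
The DummyEvent shim is gone because the tracker now owns a real asyncio.Event, and the test runs under pytest.mark.asyncio so it can await wait_for_new_requests(). For reference, a minimal sketch of the Event semantics the assertions depend on (toy code, not the vLLM tracker):

    import asyncio

    async def demo() -> None:
        event = asyncio.Event()
        assert not event.is_set()
        event.set()              # add_request() sets the event
        assert event.is_set()
        await event.wait()       # returns immediately once set
        event.clear()            # wait_for_new_requests() clears after waking
        assert not event.is_set()

    asyncio.run(demo())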
@@ -1,8 +1,9 @@
 import asyncio
+import os
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union, AsyncIterator, Callable)

 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams

 logger = init_logger(__name__)
+ENGINE_ITERATION_TIMEOUT_S = int(
+    os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))


 class AsyncEngineDeadError(RuntimeError):
     pass


-def _raise_exception_on_finish(task: asyncio.Task,
-                               request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(
+        task: asyncio.Task, error_callback: Callable[[Exception],
+                                                     None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github.")
+
+    exception = None
     try:
-        try:
-            task.result()
-        except asyncio.CancelledError:
-            return
-        except Exception as exc:
-            raise AsyncEngineDeadError(
-                msg + " See stack trace above for the actual cause.") from exc
+        task.result()
+        # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
-    except Exception as exc:
-        request_tracker.propagate_exception(exc)
-        raise exc
+    except Exception as e:
+        exception = e
+        logger.error("Engine background task failed", exc_info=e)
+        error_callback(exception)
+        raise AsyncEngineDeadError(
+            msg + " See stack trace above for the actual cause.") from e


 class AsyncStream:
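
The rewritten _raise_exception_on_finish decouples the done-callback from the RequestTracker: it reports through an injected error_callback, which lets AsyncLLMEngine latch the error before propagating it, and it additionally re-raises AsyncEngineDeadError so the failure stays loud. A minimal sketch of this done-callback pattern; raise_on_finish and crash are illustrative names:

    import asyncio
    from functools import partial
    from typing import Callable

    def raise_on_finish(task: asyncio.Task,
                        error_callback: Callable[[Exception], None]) -> None:
        # Done-callback: pull the task's exception and hand it to the owner.
        try:
            task.result()
        except Exception as e:
            error_callback(e)

    async def crash() -> None:
        raise ValueError("boom")

    async def main() -> None:
        errors: list = []
        task = asyncio.create_task(crash())
        task.add_done_callback(partial(raise_on_finish,
                                       error_callback=errors.append))
        try:
            await task            # the awaiting caller still sees the error
        except ValueError:
            pass
        await asyncio.sleep(0)    # let the done-callback run
        assert errors, "callback observed the failure"

    asyncio.run(main())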
@@ -78,13 +82,13 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
-        self.new_requests_event = None
+        self.new_requests_event = asyncio.Event()

     def __contains__(self, item):
         return item in self._request_streams

-    def init_event(self):
-        self.new_requests_event = asyncio.Event()
+    def __len__(self) -> int:
+        return len(self._request_streams)

     def propagate_exception(self,
                             exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
         (all if request_id is None)."""
         if request_id is not None:
             self._request_streams[request_id].put(exc)
+            self.abort_request(request_id)
         else:
-            for stream in self._request_streams.values():
+            for rid, stream in self._request_streams.items():
                 stream.put(exc)
+                self.abort_request(rid)

     def process_request_output(self,
                                request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)

-        self.new_requests_event.clear()
-
         return new_requests, finished_requests

     async def wait_for_new_requests(self):
-        await self.new_requests_event.wait()
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
+
+    def has_new_requests(self):
+        return not self._new_requests.empty()


 class _AsyncLLMEngine(LLMEngine):
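
wait_for_new_requests() now checks the queue before blocking and clears the event only after waking, which avoids both a lost wakeup (a request added before anyone waits) and a stale wakeup on the next call. A toy version of the check-then-wait pattern, independent of the vLLM classes:

    import asyncio

    class Tracker:
        def __init__(self) -> None:
            self._queue: asyncio.Queue = asyncio.Queue()
            self._event = asyncio.Event()

        def add(self, item) -> None:
            self._queue.put_nowait(item)
            self._event.set()             # wake any waiter

        async def wait_for_items(self) -> None:
            if self._queue.empty():       # skip waiting if work already queued
                await self._event.wait()
            self._event.clear()           # consume the wakeup

    async def main() -> None:
        tracker = Tracker()
        tracker.add("request-1")          # set before anyone waits
        await tracker.wait_for_items()    # returns immediately: no lost wakeup

    asyncio.run(main())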
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
         all_outputs = await asyncio.gather(*coros)
         return all_outputs

+    async def check_health_async(self):
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+

 class AsyncLLMEngine:
     """An asynchronous wrapper for LLMEngine.
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
         # collected
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracker: Optional[RequestTracker] = None
+        self._errored_with: Optional[BaseException] = None

     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
-                and not self.background_loop.done())
+                and not self._background_loop_unshielded.done())
+
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+        return self._errored_with is not None
+
+    def set_errored(self, exc: Exception) -> None:
+        self._errored_with = exc
+
+    def _error_callback(self, exc: Exception) -> None:
+        self.set_errored(exc)
+        self._request_tracker.propagate_exception(exc)

     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer

     def start_background_loop(self) -> None:
         """Start the background loop."""
+        if self.errored:
+            raise AsyncEngineDeadError(
+                "Background loop has errored already.") from self._errored_with
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self._request_tracker.init_event()
+        # Initialize the RequestTracker here so it uses the right event loop.
+        self._request_tracker = RequestTracker()
+
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self._request_tracker))
+                    error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)

     def _init_engine(self, *args,
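
The new errored / is_stopped / set_errored trio latches the first fatal exception so start_background_loop refuses to restart a dead loop. A stripped-down sketch of that guard; LoopState is a toy stand-in, not the vLLM class:

    from typing import Optional

    class LoopState:
        """Toy version of the error-latching guard."""

        def __init__(self) -> None:
            self._errored_with: Optional[BaseException] = None
            self._running = False

        @property
        def errored(self) -> bool:
            return self._errored_with is not None

        def set_errored(self, exc: Exception) -> None:
            self._errored_with = exc

        def start(self) -> None:
            if self.errored:
                # Once dead, always dead: restarting would hide the root cause.
                raise RuntimeError("loop has errored") from self._errored_with
            if self._running:
                raise RuntimeError("already running")
            self._running = True

    state = LoopState()
    state.set_errored(ValueError("simulated NCCL timeout"))
    try:
        state.start()
    except RuntimeError:
        pass
    assert state.errored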
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
             self.engine.abort_request(request_ids)

     async def run_engine_loop(self):
-        # Initialize the RequestTracker here so it uses the right event loop.
         has_requests_in_progress = False
         while True:
             if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
                 await self._request_tracker.wait_for_new_requests()
-            has_requests_in_progress = await self.engine_step()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
             await asyncio.sleep(0)

     async def add_request(
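
Wrapping engine_step() in asyncio.wait_for turns a silently hung iteration (e.g. an NCCL collective that never returns) into a loud failure after ENGINE_ITERATION_TIMEOUT_S seconds. The watchdog pattern in isolation, with an illustrative timeout:

    import asyncio

    TIMEOUT_S = 0.1   # illustrative; the engine reads VLLM_ENGINE_ITERATION_TIMEOUT_S

    async def hung_step() -> bool:
        await asyncio.sleep(10)   # stands in for an iteration stuck in NCCL
        return True

    async def main() -> None:
        try:
            await asyncio.wait_for(hung_step(), TIMEOUT_S)
        except asyncio.TimeoutError:
            # wait_for cancels the hung step; the engine marks itself errored
            # (set_errored) and re-raises so the failure is visible.
            print("engine iteration timed out; engine marked dead")

    asyncio.run(main())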
@@ -647,3 +692,19 @@
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
+
+    async def check_health(self):
+        """Raises an error if engine is unhealthy."""
+        t = time.perf_counter()
+        logger.debug("Starting health check...")
+        if self.is_stopped:
+            raise AsyncEngineDeadError("Background loop is stopped.")
+
+        if self.engine_use_ray:
+            try:
+                await self.engine.check_health.remote()
+            except ray.exceptions.RayActorError as e:
+                raise RuntimeError("Engine is dead.") from e
+        else:
+            await self.engine.check_health_async()
+        logger.debug(f"Health check took {time.perf_counter()-t}s")
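
check_health() is shaped to back a server-side liveness probe: it raises when the background loop is stopped or a Ray actor has died, and returns quickly otherwise. A hedged sketch of how an HTTP server might expose it; the FastAPI app and /health route here are assumptions for illustration, not part of this diff:

    from typing import Optional

    from fastapi import FastAPI
    from fastapi.responses import Response

    from vllm.engine.async_llm_engine import AsyncLLMEngine

    app = FastAPI()
    engine: Optional[AsyncLLMEngine] = None   # assumed to be created at startup

    @app.get("/health")
    async def health() -> Response:
        """200 while the engine loop is alive; 500 once it has died."""
        if engine is None:
            return Response(status_code=500)
        try:
            await engine.check_health()       # raises once the engine is dead
            return Response(status_code=200)
        except Exception:
            return Response(status_code=500)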
@@ -1119,3 +1119,23 @@ class LLMEngine:
             for worker in self.workers
         ])
         return forward_dag.experimental_compile()
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.parallel_config.worker_use_ray:
+            return
+
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead Workers: {dead_actors}. ")