From ff36139ffc66294c19b503c1e52dc42c2cd265f6 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Sun, 17 Sep 2023 00:29:08 -0700
Subject: [PATCH] Remove AsyncLLMEngine busy loop, shield background task
 (#1059)

---
 requirements-dev.txt                        |  1 +
 tests/async_engine/test_async_llm_engine.py | 80 +++++++++++++++++++++
 tests/async_engine/test_request_tracker.py  | 21 ++++++
 vllm/engine/async_llm_engine.py             | 70 +++++++++++++-----
 4 files changed, 154 insertions(+), 18 deletions(-)
 create mode 100644 tests/async_engine/test_async_llm_engine.py

diff --git a/requirements-dev.txt b/requirements-dev.txt
index f770ccec90ee..bfa1d06de101 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -11,3 +11,4 @@ types-setuptools
 # testing
 pytest
 pytest-forked
+pytest-asyncio
diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
new file mode 100644
index 000000000000..44ad201e914b
--- /dev/null
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -0,0 +1,80 @@
+import asyncio
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+
+@dataclass
+class RequestOutput:
+    request_id: int
+    finished: bool = False
+
+
+class MockEngine:
+
+    def __init__(self):
+        self.step_calls = 0
+        self.add_request_calls = 0
+        self.abort_request_calls = 0
+        self.request_id = None
+
+    async def step_async(self):
+        self.step_calls += 1
+        return [RequestOutput(
+            request_id=self.request_id)] if self.request_id else []
+
+    def generate(self, request_id):
+        self.request_id = request_id
+
+    def stop_generating(self):
+        self.request_id = None
+
+    def add_request(self, **kwargs):
+        self.add_request_calls += 1
+        return
+
+    def abort_request(self, request_id):
+        self.abort_request_calls += 1
+        return
+
+
+class MockAsyncLLMEngine(AsyncLLMEngine):
+
+    def _init_engine(self, *args, **kwargs):
+        return MockEngine()
+
+
+@pytest.mark.asyncio
+async def test_new_requests_event():
+    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine.start_background_loop()
+    await asyncio.sleep(0.01)
+    assert engine.engine.step_calls == 0
+
+    await engine.add_request("1", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 1
+    assert engine.engine.step_calls == 1
+
+    await engine.add_request("2", "", None)
+    engine.engine.generate("2")
+    await asyncio.sleep(0)
+    assert engine.engine.add_request_calls == 2
+    assert engine.engine.step_calls == 2
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 3
+    engine.engine.stop_generating()
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+
+    await engine.add_request("3", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
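A note on the timing in the test above: await asyncio.sleep(0) yields to the event loop for exactly one pass, so while the mock engine reports an in-progress request the background loop performs exactly one engine_step per yield. After stop_generating(), the next step returns no outputs, the loop goes back to waiting on the new-requests event, and step_calls stops advancing at 4. The following standalone sketch (not part of the patch; the names background and steps are invented) shows the yield-once behaviour the assertions rely on:

import asyncio


async def main():
    steps = 0

    async def background():
        nonlocal steps
        while True:
            steps += 1  # one simulated "engine step"
            await asyncio.sleep(0)  # yield, as run_engine_loop() does

    task = asyncio.create_task(background())
    await asyncio.sleep(0)  # one event-loop pass: exactly one step runs
    assert steps == 1
    await asyncio.sleep(0)  # one more pass: one more step
    assert steps == 2
    task.cancel()


asyncio.run(main())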
diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py
index 3666e6c7e626..7787381f97d1 100644
--- a/tests/async_engine/test_request_tracker.py
+++ b/tests/async_engine/test_request_tracker.py
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
 from vllm.outputs import RequestOutput
 
 
+class DummyEvent:
+
+    def __init__(self):
+        self._flag = False
+
+    def set(self):
+        self._flag = True
+
+    def clear(self):
+        self._flag = False
+
+
 def test_request_tracker():
     tracker = RequestTracker()
+    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -15,7 +30,9 @@ def test_request_tracker():
 
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -26,6 +43,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
+    assert not tracker.new_requests_event._flag
 
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -36,6 +54,7 @@ def test_request_tracker():
 
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -43,9 +62,11 @@ def test_request_tracker():
     assert stream_4.finished
 
     stream_5 = tracker.add_request("5")
+    assert tracker.new_requests_event._flag
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1
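For illustration (not part of the patch), the pattern these tracker tests exercise is set-on-publish, clear-on-drain: add_request() sets the event so a waiting consumer wakes up, and get_new_and_finished_requests() clears it once the queues are drained. The test injects DummyEvent because new_requests_event starts out as None and the real asyncio.Event is only created by init_event() from inside the running loop. A reduced sketch of the same pattern, with invented names (TinyTracker, consumer):

import asyncio


class TinyTracker:

    def __init__(self):
        self._items = []
        self.event = asyncio.Event()

    def add(self, item):
        self._items.append(item)
        self.event.set()  # wake any waiter: there is work to pick up

    def drain(self):
        items, self._items = self._items, []
        self.event.clear()  # nothing pending until the next add()
        return items


async def consumer(tracker):
    while True:
        await tracker.event.wait()  # sleep instead of busy-polling
        print(tracker.drain())


async def main():
    tracker = TinyTracker()
    task = asyncio.create_task(consumer(tracker))
    tracker.add("request-1")
    await asyncio.sleep(0.01)  # let the consumer drain; prints ['request-1']
    task.cancel()


asyncio.run(main())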
"""Extension of LLMEngine to add async methods.""" @@ -251,9 +270,13 @@ class AsyncLLMEngine: self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) - self.request_tracker: RequestTracker = RequestTracker() self.background_loop = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded = None self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() @property def is_running(self) -> bool: @@ -264,11 +287,14 @@ class AsyncLLMEngine: """Start the background loop.""" if self.is_running: raise RuntimeError("Background loop is already running.") - self.background_loop = asyncio.get_event_loop().create_task( - self.run_engine_loop()) - self.background_loop.add_done_callback( + self._request_tracker.init_event() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop()) + self._background_loop_unshielded.add_done_callback( partial(_raise_exception_on_finish, - request_tracker=self.request_tracker)) + request_tracker=self._request_tracker)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) def _init_engine(self, *args, **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: @@ -280,11 +306,13 @@ class AsyncLLMEngine: engine_class = ray.remote(num_gpus=1)(self._engine_class).remote return engine_class(*args, **kwargs) - async def engine_step(self): - """Kick the engine to process the waiting requests.""" + async def engine_step(self) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" new_requests, finished_requests = ( - self.request_tracker.get_new_and_finished_requests()) + self._request_tracker.get_new_and_finished_requests()) for new_request in new_requests: # Add the request into the vLLM engine's waiting queue. @@ -304,9 +332,11 @@ class AsyncLLMEngine: # Put the outputs into the corresponding streams. for request_output in request_outputs: - self.request_tracker.process_request_output( + self._request_tracker.process_request_output( request_output, verbose=self.log_requests) + return len(request_outputs) > 0 + async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: await self.engine.abort_request.remote(request_ids) @@ -314,8 +344,12 @@ class AsyncLLMEngine: self.engine.abort_request(request_ids) async def run_engine_loop(self): + # Initialize the RequestTracker here so it uses the right event loop. + has_requests_in_progress = False while True: - await self.engine_step() + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + has_requests_in_progress = await self.engine_step() await asyncio.sleep(0) async def add_request( @@ -350,7 +384,7 @@ class AsyncLLMEngine: "error that caused the background loop to stop " "(AsyncEngineDeadError).") - stream = self.request_tracker.add_request( + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, @@ -428,8 +462,8 @@ class AsyncLLMEngine: Args: request_id: The unique id of the request. """ - self.request_tracker.abort_request(request_id, - verbose=self.log_requests) + self._request_tracker.abort_request(request_id, + verbose=self.log_requests) async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine."""