diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index d763f2c2e07b..03494581431d 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional

 import pytest
+import pytest_asyncio
 import torch

 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput

+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear


@@ -118,15 +123,38 @@ async def test_new_requests_event():
         os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")


-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )

-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):

     async def run(prompt: str):
         sampling_params = SamplingParams(
@@ -134,17 +162,64 @@ def test_asyncio_run():
             max_tokens=32,
         )

-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output

-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index b33c19e97141..ceda0b83a239 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@ import asyncio
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)

 import torch
 from typing_extensions import assert_never
@@ -85,9 +85,8 @@ class AsyncStream:

     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                               Exception]) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+        if not self._finished:
+            self._queue.put_nowait(item)

     def finish(
         self,
@@ -96,7 +95,7 @@ class AsyncStream:
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)

     @property
     def finished(self) -> bool:
@@ -106,9 +105,9 @@ class AsyncStream:
             self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, Exception):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
@@ -117,6 +116,12 @@ class AsyncStream:
             self._cancel(self.request_id)
             raise asyncio.CancelledError from None

+    @staticmethod
+    def _is_raisable(value: Any):
+        return isinstance(value, BaseException) or \
+                (isinstance(value, type) and \
+                    issubclass(value, BaseException))
+

 class RequestTracker:
     """Synchronous abstraction for tracking requests."""
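
Note: the following is a minimal standalone sketch (not part of the diff; the function name and examples are illustrative only) of the check performed by the new AsyncStream._is_raisable static method. It shows why both exception instances and exception classes, such as the STOP_ITERATION sentinel, can be placed on the stream's queue and later raised or compared by the generator.

import asyncio
from typing import Any


def is_raisable(value: Any) -> bool:
    # Mirrors the check added to AsyncStream in this diff: a queue item is
    # "raisable" if it is an exception instance or an exception class.
    return isinstance(value, BaseException) or (
        isinstance(value, type) and issubclass(value, BaseException))


assert is_raisable(ValueError("boom"))         # exception instance
assert is_raisable(asyncio.CancelledError)     # exception class
assert not is_raisable("an ordinary output")   # regular queue item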