From dfea17314827845d55dabb03ebe905f58e6682e4 Mon Sep 17 00:00:00 2001
From: Ruoyu Qin
Date: Sun, 28 Apr 2024 00:48:37 +0800
Subject: [PATCH] [Bugfix] Abort requests when the connection to /v1/completions is interrupted (#4363)

---
 .../test_merge_async_iterators.py             | 46 +++++++++++++++++++
 vllm/utils.py                                 | 19 +++++---
 2 files changed, 60 insertions(+), 5 deletions(-)
 create mode 100644 tests/async_engine/test_merge_async_iterators.py

diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py
new file mode 100644
index 000000000000..ea453526c77f
--- /dev/null
+++ b/tests/async_engine/test_merge_async_iterators.py
@@ -0,0 +1,46 @@
+import asyncio
+from typing import AsyncIterator, Tuple
+
+import pytest
+
+from vllm.utils import merge_async_iterators
+
+
+@pytest.mark.asyncio
+async def test_merge_async_iterators():
+    """Cancelling the merged consumer must also stop every source iterator."""
+
+    async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
+        try:
+            while True:
+                yield f"item from iterator {idx}"
+                await asyncio.sleep(0.1)
+        except asyncio.CancelledError:
+            # Swallow the cancellation so the generator simply finishes;
+            # later __anext__() calls then raise StopAsyncIteration.
+            pass
+
+    iterators = [mock_async_iterator(i) for i in range(3)]
+    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
+        *iterators)
+
+    async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
+        async for idx, output in generator:
+            print(f"idx: {idx}, output: {output}")
+
+    task = asyncio.create_task(stream_output(merged_iterator))
+    await asyncio.sleep(0.5)
+    task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await task
+
+    for iterator in iterators:
+        try:
+            # The anext() builtin only exists on Python 3.10+, so call
+            # __anext__() directly to keep the test runnable on older versions.
+            await asyncio.wait_for(iterator.__anext__(), 1)
+        except StopAsyncIteration:
+            # All iterators should be cancelled and print this message.
+            print("Iterator was cancelled normally")
+        except (Exception, asyncio.CancelledError) as e:
+            raise AssertionError() from e
diff --git a/vllm/utils.py b/vllm/utils.py
index 76c2fc66e47c..88447878f170 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -225,11 +225,20 @@ def merge_async_iterators(
     ]
 
     async def consumer():
-        while not all(finished) or not queue.empty():
-            item = await queue.get()
-            if isinstance(item, Exception):
-                raise item
-            yield item
+        try:
+            while not all(finished) or not queue.empty():
+                item = await queue.get()
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+        except (Exception, asyncio.CancelledError) as e:
+            # The consumer failed or was cancelled: stop the producer tasks
+            # so they do not keep pulling from the source iterators.
+            for task in _tasks:
+                # NOTE: Pass the error msg in cancel()
+                # when only Python 3.9+ is supported.
+                task.cancel()
+            raise e
         await asyncio.gather(*_tasks)
 
     return consumer()