[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)

Nick Hill 2024-08-21 11:45:55 -04:00 committed by GitHub
parent dd3fa0e430
commit c75363fbc0
2 changed files with 101 additions and 21 deletions
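
The old AsyncStream consumer loop re-checked the stream's finished flag before every queue read, so outputs that were still queued when finish() ran could be silently dropped, and only plain Exception instances were re-raised. Below is a minimal standalone sketch of that failure mode and the fix (illustrative names only, not vLLM's actual classes):

import asyncio

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    # Producer has already queued outputs and marked the stream finished.
    for item in ("out0", "out1", "out2"):
        queue.put_nowait(item)
    queue.put_nowait(sentinel)
    finished = True

    # Buggy consumer: exits immediately, dropping all queued outputs.
    buggy = []
    while not finished:
        item = await queue.get()
        if item is sentinel:
            break
        buggy.append(item)

    # Fixed consumer: drains until the sentinel, regardless of the flag.
    fixed = []
    while True:
        item = await queue.get()
        if item is sentinel:
            break
        fixed.append(item)

    assert buggy == []
    assert fixed == ["out0", "out1", "out2"]

asyncio.run(main())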

View File

@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional

 import pytest
+import pytest_asyncio
 import torch

 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput

+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
@@ -118,15 +123,38 @@ async def test_new_requests_event():
     os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")


-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )

-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):

     async def run(prompt: str):
         sampling_params = SamplingParams(
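
The new fixture builds the engine with run_in_executor so the blocking model load happens in a worker thread instead of stalling the event loop. A minimal sketch of that pattern (the slow_build stand-in is hypothetical):

import asyncio
import time


def slow_build():
    # Hypothetical stand-in for a blocking constructor such as a model load.
    time.sleep(0.2)
    return "engine"


async def main():
    loop = asyncio.get_event_loop()
    # Offload the blocking call to the default thread pool so coroutines
    # already running on this loop are not starved while it completes.
    engine = await loop.run_in_executor(None, slow_build)
    assert engine == "engine"


asyncio.run(main())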
@@ -134,17 +162,64 @@ def test_asyncio_run():
             max_tokens=32,
         )

-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output

-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
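
These tests share one module-scoped engine via pytest_asyncio, with should_do_global_cleanup_after_test overridden so the per-test cleanup does not tear the shared engine down. A self-contained sketch of the sharing pattern, assuming pytest-asyncio >= 0.23 (FakeEngine is a hypothetical stand-in for AsyncLLMEngine):

import asyncio

import pytest
import pytest_asyncio


class FakeEngine:
    # Hypothetical engine: streams a few tokens per prompt.
    async def generate(self, prompt: str):
        for i in range(3):
            await asyncio.sleep(0)
            yield f"{prompt}-token{i}"


@pytest_asyncio.fixture(scope="module")
async def engine():
    # Built once per module; teardown code would follow the yield.
    yield FakeEngine()


@pytest.mark.asyncio(scope="module")
async def test_shared_engine(engine):
    outputs = [tok async for tok in engine.generate("hi")]
    assert len(outputs) == 3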

View File

@@ -2,8 +2,8 @@ import asyncio
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)

 import torch
 from typing_extensions import assert_never
@ -85,9 +85,8 @@ class AsyncStream:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None: Exception]) -> None:
if self._finished: if not self._finished:
return self._queue.put_nowait(item)
self._queue.put_nowait(item)
def finish( def finish(
self, self,
@@ -96,7 +95,7 @@ class AsyncStream:
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)

     @property
     def finished(self) -> bool:
@@ -106,9 +105,9 @@ class AsyncStream:
         self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, Exception):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
@ -117,6 +116,12 @@ class AsyncStream:
self._cancel(self.request_id) self._cancel(self.request_id)
raise asyncio.CancelledError from None raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))
class RequestTracker: class RequestTracker:
"""Synchronous abstraction for tracking requests.""" """Synchronous abstraction for tracking requests."""