Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 22:45:50 +08:00)
[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)
Commit c75363fbc0 (parent dd3fa0e430)
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
+import pytest_asyncio
 import torch
 
 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput
 
+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
 
 
@@ -118,15 +123,38 @@ async def test_new_requests_event():
     os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
 
 
-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )
 
-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):
 
     async def run(prompt: str):
         sampling_params = SamplingParams(
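`AsyncLLMEngine.from_engine_args` blocks while the model loads, so the new fixture builds the engine in the default thread pool rather than on the event loop. A minimal sketch of that `run_in_executor` pattern, with `blocking_setup` as a hypothetical stand-in for the engine constructor:

import asyncio


def blocking_setup() -> str:
    # Stands in for AsyncLLMEngine.from_engine_args, which blocks while
    # loading model weights and must not run on the event loop thread.
    return "engine"


async def main() -> None:
    # Offload the blocking constructor to the default thread pool so the
    # event loop stays responsive while the engine starts up.
    engine = await asyncio.get_running_loop().run_in_executor(None, blocking_setup)
    assert engine == "engine"


asyncio.run(main())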
@@ -134,17 +162,64 @@ def test_asyncio_run():
             max_tokens=32,
         )
 
-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output
 
-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
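These tests share one engine across the whole module, which only works because the asyncio marker and the fixture use matching module scope: pytest-asyncio then runs every test in the module on a single event loop, so the engine's background loop survives between tests. A minimal, self-contained sketch of that pattern (assumes pytest and a pytest-asyncio version that accepts `scope=` on the marker, as the diff itself does; `shared_resource` is a hypothetical stand-in for the engine):

import asyncio

import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="module")
async def shared_resource():
    # Built once per module, on the same event loop every test uses.
    yield object()


@pytest.mark.asyncio(scope="module")
async def test_uses_shared_resource(shared_resource):
    await asyncio.sleep(0)
    assert shared_resource is not None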
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@ import asyncio
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)
 
 import torch
 from typing_extensions import assert_never
@@ -85,9 +85,8 @@ class AsyncStream:
 
     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                               Exception]) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+        if not self._finished:
+            self._queue.put_nowait(item)
 
     def finish(
             self,
@@ -96,7 +95,7 @@ class AsyncStream:
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
@@ -106,9 +105,9 @@ class AsyncStream:
             self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, Exception):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
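This loop change is the "premature async generator exit" fix: the old loop stopped as soon as `self._finished` was set, even though `finish()` only enqueues a sentinel after the remaining outputs, so anything still sitting in the queue was silently dropped; looping until the sentinel drains the queue first. A runnable sketch of the fixed behavior, with `sentinel` as a hypothetical stand-in for `STOP_ITERATION`:

import asyncio


async def drain(queue: asyncio.Queue, sentinel: object) -> list:
    items = []
    # The fixed loop: run until the sentinel arrives, not until some
    # external "finished" flag flips, so queued items are never lost.
    while True:
        item = await queue.get()
        if item is sentinel:
            return items
        items.append(item)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()
    # Producer finishes before the consumer even starts, mirroring the
    # slow-consumer scenario test_delayed_generator exercises.
    for i in range(10):
        queue.put_nowait(i)
    queue.put_nowait(sentinel)
    assert await drain(queue, sentinel) == list(range(10))


asyncio.run(main())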
@@ -117,6 +116,12 @@ class AsyncStream:
             self._cancel(self.request_id)
             raise asyncio.CancelledError from None
 
+    @staticmethod
+    def _is_raisable(value: Any):
+        return isinstance(value, BaseException) or \
+                (isinstance(value, type) and \
+                 issubclass(value, BaseException))
+
 
 class RequestTracker:
     """Synchronous abstraction for tracking requests."""
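`_is_raisable` is what lets the stream raise "all exception variations": the old `isinstance(result, Exception)` check missed exception classes (which a `raise` statement also accepts) and `BaseException` subclasses such as `asyncio.CancelledError`, which no longer derives from `Exception` on Python 3.8+. A quick standalone illustration of the cases the narrower check misses:

import asyncio

# isinstance(..., Exception) misses both raisable variations:
assert not isinstance(asyncio.CancelledError(), Exception)  # BaseException subclass
assert not isinstance(ValueError, Exception)  # a class, not an instance


def is_raisable(value):
    # Same predicate as the commit's _is_raisable helper.
    return isinstance(value, BaseException) or (
        isinstance(value, type) and issubclass(value, BaseException))


assert is_raisable(asyncio.CancelledError())
assert is_raisable(ValueError)
assert not is_raisable("not an exception")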