[Bugfix] Fix request cancellation without polling (#11190)

Joe Runde 2024-12-17 13:26:32 -07:00 committed by GitHub
parent f9ecbb18bf
commit 2d1b9baa8f
12 changed files with 164 additions and 103 deletions
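
In short: the old approach of polling an is_cancelled() callback from inside the result iterators (see the code removed from vllm/utils.py at the bottom of this diff) is replaced by awaiting the ASGI http.disconnect message and racing it against the handler. Below is a standalone sketch of that pattern with illustrative names; the real implementation is the new `with_cancellation` decorator added in vllm/entrypoints/utils.py further down.

import asyncio

async def run_cancellable(work, wait_for_disconnect):
    # Race the real work against a disconnect listener and cancel the loser.
    work_task = asyncio.create_task(work)
    disconnect_task = asyncio.create_task(wait_for_disconnect)
    done, pending = await asyncio.wait({work_task, disconnect_task},
                                       return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    # If the client disconnected first, the work was cancelled and there is
    # nothing to return.
    return work_task.result() if work_task in done else None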


@@ -1,6 +1,8 @@
import asyncio
from http import HTTPStatus
from typing import List
import openai
import pytest
import pytest_asyncio
import requests
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
response = requests.get(server.url_for("health"))
assert response.status_code == HTTPStatus.OK
@pytest.mark.parametrize(
"server_args",
[
pytest.param(["--max-model-len", "10100"],
id="default-frontend-multiprocessing"),
pytest.param(
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
id="disable-frontend-multiprocessing")
],
indirect=True,
)
@pytest.mark.asyncio
async def test_request_cancellation(server: RemoteOpenAIServer):
# clunky test: send an ungodly amount of load in with short timeouts
# then ensure that it still responds quickly afterwards
chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client(timeout=0.5)
tasks = []
# Request about 2 million tokens
for _ in range(200):
task = asyncio.create_task(
client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_body={"min_tokens": 10000}))
tasks.append(task)
done, pending = await asyncio.wait(tasks,
return_when=asyncio.ALL_COMPLETED)
# Make sure all requests were sent to the server and timed out
# (We don't want to hide other errors like 400s that would invalidate this
# test)
assert len(pending) == 0
for d in done:
with pytest.raises(openai.APITimeoutError):
d.result()
# If the server had not cancelled all the other requests, then it would not
# be able to respond to this one within the timeout
client = server.get_async_client(timeout=5)
response = await client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10)
assert len(response.choices) == 1


@@ -1,7 +1,6 @@
import asyncio
import os
import socket
from functools import partial
from typing import AsyncIterator, Tuple
import pytest
@@ -26,10 +25,7 @@ async def test_merge_async_iterators():
print(f"iterator {idx} cancelled")
iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators,
is_cancelled=partial(asyncio.sleep,
0,
result=False))
merged_iterator = merge_async_iterators(*iterators)
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
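
With the is_cancelled callback gone, merge_async_iterators is driven purely by the iterators themselves and cancellation is handled by whoever consumes the merged stream. A small standalone sketch of the new call shape (illustrative toy generators, not from this diff; assumes vllm is importable):

import asyncio
from typing import AsyncGenerator

from vllm.utils import merge_async_iterators

async def toy_gen(i: int) -> AsyncGenerator[str, None]:
    # Illustrative producer; real callers pass engine output generators.
    for j in range(2):
        await asyncio.sleep(0)
        yield f"gen{i}-item{j}"

async def main() -> None:
    # New signature: just the iterators, no is_cancelled polling callback.
    async for idx, item in merge_async_iterators(toy_gen(0), toy_gen(1)):
        print(idx, item)

asyncio.run(main())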


@@ -163,12 +163,11 @@ class RemoteOpenAIServer:
api_key=self.DUMMY_API_KEY,
)
def get_async_client(self):
return openai.AsyncOpenAI(
base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
)
def get_async_client(self, **kwargs):
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs)
def _test_completion(


@@ -1065,16 +1065,20 @@ class AsyncLLMEngine(EngineClient):
>>> # Process and return the final output
>>> ...
"""
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
try:
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
async def encode(
self,
@@ -1147,15 +1151,19 @@ class AsyncLLMEngine(EngineClient):
>>> # Process and return the final output
>>> ...
"""
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
try:
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
async def abort(self, request_id: str) -> None:
"""Abort a request.


@@ -17,11 +17,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
@@ -47,6 +47,11 @@ async def generate(request: Request) -> Response:
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
return await _generate(request_dict, raw_request=request)
@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
@@ -54,8 +59,6 @@ async def generate(request: Request) -> Response:
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:


@@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@@ -353,6 +356,7 @@ async def show_version():
@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
if handler is None:
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "


@@ -32,7 +32,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import iterate_with_cancellation
logger = init_logger(__name__)
@@ -234,10 +233,6 @@ class OpenAIServingChat(OpenAIServing):
assert len(generators) == 1
result_generator, = generators
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(


@@ -159,8 +159,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
result_generator = merge_async_iterators(*generators)
model_name = self._get_model_name(lora_request)
num_prompts = len(engine_prompts)


@@ -202,10 +202,7 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)


@@ -186,10 +186,7 @@ class OpenAIServingScores(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)

vllm/entrypoints/utils.py (new file, 57 lines)

@@ -0,0 +1,57 @@
import asyncio
import functools
from fastapi import Request
async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
break
def with_cancellation(handler_func):
"""Decorator that allows a route handler to be cancelled by client
disconnections.
This does _not_ use request.is_disconnected, which does not work with
middleware. Instead this follows the pattern from
starlette.StreamingResponse, which simultaneously awaits on two tasks: one
to wait for an http disconnect message, and the other to do the work that we
want done. When the first task finishes, the other is cancelled.
A core assumption of this method is that the body of the request has already
been read. This is a safe assumption to make for fastapi handlers that have
already parsed the body of the request into a pydantic model for us.
This decorator is unsafe to use elsewhere, as it will consume and throw away
all incoming messages for the request while it looks for a disconnect
message.
In the case where a `StreamingResponse` is returned by the handler, this
wrapper will stop listening for disconnects and instead the response object
will start listening for disconnects.
"""
# Functools.wraps is required for this wrapper to appear to fastapi as a
# normal route handler, with the correct request type hinting.
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
# The request is either the second positional arg or `raw_request`
request = args[1] if len(args) > 1 else kwargs["raw_request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
return handler_task.result()
return None
return wrapper
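
For reference, a minimal hypothetical FastAPI route (not part of this diff) wired up the same way as the OpenAI server routes earlier in this diff: the decorator sits between the route registration and the handler, and the Request parameter must be named `raw_request` (or be the second positional argument) so the wrapper can find it.

from fastapi import FastAPI, Request
from pydantic import BaseModel

from vllm.entrypoints.utils import with_cancellation

app = FastAPI()

class EchoRequest(BaseModel):
    text: str

@app.post("/echo")
@with_cancellation
async def echo(request: EchoRequest, raw_request: Request):
    # FastAPI has already parsed the body into `request`, which satisfies the
    # decorator's assumption; from here on, a client disconnect cancels this
    # handler task instead of letting it run to completion unobserved.
    return {"echo": request.text}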


@@ -20,7 +20,7 @@ import time
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
@@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None],
return loop.create_task(iterator.__anext__()) # type: ignore[arg-type]
async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
"""Convert async iterator into one that polls the provided function
at least once per second to check for client cancellation.
"""
loop = asyncio.get_running_loop()
awaits: List[Future[T]] = [_next_task(iterator, loop)]
next_cancel_check: float = 0
while True:
done, pending = await asyncio.wait(awaits, timeout=1.5)
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
if done:
try:
item = await awaits[0]
awaits[0] = _next_task(iterator, loop)
yield item
except StopAsyncIteration:
# we are done
return
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
*iterators: AsyncGenerator[T,
None], ) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
It also optionally polls a provided function at least once per second
to check for client cancellation.
"""
loop = asyncio.get_running_loop()
awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
timeout = None if is_cancelled is None else 1.5
next_cancel_check: float = 0
try:
while awaits:
done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED,
timeout=timeout)
if is_cancelled is not None:
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
done, _ = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED)
for d in done:
pair = awaits.pop(d)
try: