[Bugfix] Fix request cancellation without polling (#11190)

Joe Runde 2024-12-17 13:26:32 -07:00 committed by GitHub
parent f9ecbb18bf
commit 2d1b9baa8f
12 changed files with 164 additions and 103 deletions


@@ -1,6 +1,8 @@
+import asyncio
 from http import HTTPStatus
 from typing import List

+import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
     response = requests.get(server.url_for("health"))

     assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--max-model-len", "10100"],
+                     id="default-frontend-multiprocessing"),
+        pytest.param(
+            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
+            id="disable-frontend-multiprocessing")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_request_cancellation(server: RemoteOpenAIServer):
+    # clunky test: send an ungodly amount of load in with short timeouts
+    # then ensure that it still responds quickly afterwards
+
+    chat_input = [{"role": "user", "content": "Write a long story"}]
+    client = server.get_async_client(timeout=0.5)
+    tasks = []
+    # Request about 2 million tokens
+    for _ in range(200):
+        task = asyncio.create_task(
+            client.chat.completions.create(messages=chat_input,
+                                           model=MODEL_NAME,
+                                           max_tokens=10000,
+                                           extra_body={"min_tokens": 10000}))
+        tasks.append(task)
+
+    done, pending = await asyncio.wait(tasks,
+                                       return_when=asyncio.ALL_COMPLETED)
+
+    # Make sure all requests were sent to the server and timed out
+    # (We don't want to hide other errors like 400s that would invalidate this
+    # test)
+    assert len(pending) == 0
+    for d in done:
+        with pytest.raises(openai.APITimeoutError):
+            d.result()
+
+    # If the server had not cancelled all the other requests, then it would not
+    # be able to respond to this one within the timeout
+    client = server.get_async_client(timeout=5)
+    response = await client.chat.completions.create(messages=chat_input,
+                                                    model=MODEL_NAME,
+                                                    max_tokens=10)
+
+    assert len(response.choices) == 1


@@ -1,7 +1,6 @@
 import asyncio
 import os
 import socket
-from functools import partial
 from typing import AsyncIterator, Tuple

 import pytest
@@ -26,10 +25,7 @@ async def test_merge_async_iterators():
             print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator = merge_async_iterators(*iterators,
-                                            is_cancelled=partial(asyncio.sleep,
-                                                                 0,
-                                                                 result=False))
+    merged_iterator = merge_async_iterators(*iterators)

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:


@@ -163,12 +163,11 @@ class RemoteOpenAIServer:
             api_key=self.DUMMY_API_KEY,
         )

-    def get_async_client(self):
-        return openai.AsyncOpenAI(
-            base_url=self.url_for("v1"),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-        )
+    def get_async_client(self, **kwargs):
+        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
+                                  api_key=self.DUMMY_API_KEY,
+                                  max_retries=0,
+                                  **kwargs)

     def _test_completion(


@@ -1065,16 +1065,20 @@ class AsyncLLMEngine(EngineClient):
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-            request_id,
-            prompt,
-            sampling_params,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-            priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, RequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    sampling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, RequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

     async def encode(
         self,
@@ -1147,15 +1151,19 @@ class AsyncLLMEngine(EngineClient):
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-            request_id,
-            prompt,
-            pooling_params,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    pooling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

     async def abort(self, request_id: str) -> None:
         """Abort a request.

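The two hunks above rely on standard asyncio behavior: when the task consuming generate() or encode() is cancelled (for example because the HTTP handler serving the request is cancelled), asyncio.CancelledError is raised at the generator's current await point, which is what lets the new except block abort the request in the engine before re-raising. A minimal, self-contained sketch of that mechanism; the numbers/consume names are made up for illustration and are not part of this commit:

import asyncio


async def numbers():
    # Stand-in for AsyncLLMEngine.generate(): yield until cancelled, then clean up.
    try:
        i = 0
        while True:
            await asyncio.sleep(0.1)  # stand-in for waiting on engine output
            yield i
            i += 1
    except asyncio.CancelledError:
        print("generator cancelled -> abort the request here")
        raise


async def consume():
    async for n in numbers():
        print("got", n)


async def main():
    task = asyncio.create_task(consume())
    await asyncio.sleep(0.35)
    task.cancel()  # what the HTTP layer does once the client disconnects
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())

Running this prints a few values, then the generator's except block fires as soon as the consumer task is cancelled, mirroring how the engine-side abort is triggered without any polling.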

@@ -17,11 +17,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
-                        random_uuid)
+from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger("vllm.entrypoints.api_server")
@@ -47,6 +47,11 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
+    return await _generate(request_dict, raw_request=request)
+
+
+@with_cancellation
+async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
@@ -54,8 +59,6 @@ async def generate(request: Request) -> Response:
     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
-    results_generator = iterate_with_cancellation(
-        results_generator, is_cancelled=request.is_disconnected)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:


@@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:

 @router.post("/tokenize")
+@with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):

 @router.post("/detokenize")
+@with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
@@ -353,6 +356,7 @@ async def show_version():

 @router.post("/v1/chat/completions")
+@with_cancellation
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
+@with_cancellation
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
+@with_cancellation
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):

 @router.post("/score")
+@with_cancellation
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):

 @router.post("/v1/score")
+@with_cancellation
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "


@@ -32,7 +32,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
-from vllm.utils import iterate_with_cancellation

 logger = init_logger(__name__)
@@ -234,10 +233,6 @@ class OpenAIServingChat(OpenAIServing):
             assert len(generators) == 1
         result_generator, = generators

-        if raw_request:
-            result_generator = iterate_with_cancellation(
-                result_generator, raw_request.is_disconnected)
-
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(


@@ -159,8 +159,7 @@ class OpenAIServingCompletion(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators, is_cancelled=raw_request.is_disconnected)
+        result_generator = merge_async_iterators(*generators)

         model_name = self._get_model_name(lora_request)
         num_prompts = len(engine_prompts)


@@ -202,10 +202,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)

         num_prompts = len(engine_prompts)


@@ -186,10 +186,7 @@ class OpenAIServingScores(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)

         num_prompts = len(engine_prompts)

vllm/entrypoints/utils.py (new file, +57)

@@ -0,0 +1,57 @@
+import asyncio
+import functools
+
+from fastapi import Request
+
+
+async def listen_for_disconnect(request: Request) -> None:
+    """Returns if a disconnect message is received"""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            break
+
+
+def with_cancellation(handler_func):
+    """Decorator that allows a route handler to be cancelled by client
+    disconnections.
+
+    This does _not_ use request.is_disconnected, which does not work with
+    middleware. Instead this follows the pattern from
+    starlette.StreamingResponse, which simultaneously awaits on two tasks- one
+    to wait for an http disconnect message, and the other to do the work that we
+    want done. When the first task finishes, the other is cancelled.
+
+    A core assumption of this method is that the body of the request has already
+    been read. This is a safe assumption to make for fastapi handlers that have
+    already parsed the body of the request into a pydantic model for us.
+    This decorator is unsafe to use elsewhere, as it will consume and throw away
+    all incoming messages for the request while it looks for a disconnect
+    message.
+
+    In the case where a `StreamingResponse` is returned by the handler, this
+    wrapper will stop listening for disconnects and instead the response object
+    will start listening for disconnects.
+    """
+
+    # Functools.wraps is required for this wrapper to appear to fastapi as a
+    # normal route handler, with the correct request type hinting.
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+
+        # The request is either the second positional arg or `raw_request`
+        request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+        done, pending = await asyncio.wait([handler_task, cancellation_task],
+                                           return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper

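As a reading aid, here is a hedged sketch of how this decorator is stacked under a route decorator, mirroring the @router.post(...) / @with_cancellation pattern added elsewhere in this commit. The /echo endpoint, the EchoRequest model, and the sleep are invented for illustration and are not part of the change:

import asyncio

from fastapi import FastAPI, Request
from pydantic import BaseModel

from vllm.entrypoints.utils import with_cancellation

app = FastAPI()


class EchoRequest(BaseModel):  # hypothetical request model
    text: str


@app.post("/echo")       # route decorator first, as in the diffs above
@with_cancellation       # then the cancellation wrapper
async def echo(request: EchoRequest, raw_request: Request):
    # FastAPI has already read and parsed the body into `request`, which is the
    # core assumption the decorator's docstring calls out; `raw_request` is the
    # keyword argument the wrapper looks up to listen for http.disconnect.
    await asyncio.sleep(30)  # stand-in for slow inference work
    return {"echo": request.text}

Because functools.wraps exposes the original signature, FastAPI still injects both request and raw_request; the wrapper then races the handler against listen_for_disconnect and cancels whichever task loses.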

@@ -20,7 +20,7 @@ import time
 import uuid
 import warnings
 import weakref
-from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
+from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
@@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None],
     return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]


-async def iterate_with_cancellation(
-    iterator: AsyncGenerator[T, None],
-    is_cancelled: Callable[[], Awaitable[bool]],
-) -> AsyncGenerator[T, None]:
-    """Convert async iterator into one that polls the provided function
-    at least once per second to check for client cancellation.
-    """
-
-    loop = asyncio.get_running_loop()
-
-    awaits: List[Future[T]] = [_next_task(iterator, loop)]
-    next_cancel_check: float = 0
-    while True:
-        done, pending = await asyncio.wait(awaits, timeout=1.5)
-
-        # Check for cancellation at most once per second
-        time_now = time.time()
-        if time_now >= next_cancel_check:
-            if await is_cancelled():
-                with contextlib.suppress(BaseException):
-                    awaits[0].cancel()
-                    await iterator.aclose()
-                raise asyncio.CancelledError("client cancelled")
-            next_cancel_check = time_now + 1
-
-        if done:
-            try:
-                item = await awaits[0]
-                awaits[0] = _next_task(iterator, loop)
-                yield item
-            except StopAsyncIteration:
-                # we are done
-                return
-
-
 async def merge_async_iterators(
-    *iterators: AsyncGenerator[T, None],
-    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
-) -> AsyncGenerator[Tuple[int, T], None]:
+        *iterators: AsyncGenerator[T,
+                                   None], ) -> AsyncGenerator[Tuple[int, T], None]:
     """Merge multiple asynchronous iterators into a single iterator.

     This method handle the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
-
-    It also optionally polls a provided function at least once per second
-    to check for client cancellation.
     """
     loop = asyncio.get_running_loop()

     awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
-    timeout = None if is_cancelled is None else 1.5
-    next_cancel_check: float = 0
     try:
         while awaits:
-            done, pending = await asyncio.wait(awaits.keys(),
-                                               return_when=FIRST_COMPLETED,
-                                               timeout=timeout)
-            if is_cancelled is not None:
-                # Check for cancellation at most once per second
-                time_now = time.time()
-                if time_now >= next_cancel_check:
-                    if await is_cancelled():
-                        raise asyncio.CancelledError("client cancelled")
-                    next_cancel_check = time_now + 1
+            done, _ = await asyncio.wait(awaits.keys(),
+                                         return_when=FIRST_COMPLETED)
             for d in done:
                 pair = awaits.pop(d)
                 try:
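With the polling variant removed, merge_async_iterators is a plain fan-in helper and cancellation reaches it the same way as any other awaited task. A small usage sketch under that assumption; the toy gen producers are illustrative and not part of the commit:

import asyncio

from vllm.utils import merge_async_iterators


async def gen(name: str, n: int):
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{name}-{i}"


async def main():
    # Items arrive as (index, item) pairs, tagged with the producer's position.
    async for idx, item in merge_async_iterators(gen("a", 2), gen("b", 3)):
        print(idx, item)


asyncio.run(main())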