Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:54:56 +08:00)

[Bugfix] Fix request cancellation without polling (#11190)

parent f9ecbb18bf, commit 2d1b9baa8f
@@ -1,6 +1,8 @@
+import asyncio
 from http import HTTPStatus
 from typing import List
 
+import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
     response = requests.get(server.url_for("health"))
 
     assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--max-model-len", "10100"],
+                     id="default-frontend-multiprocessing"),
+        pytest.param(
+            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
+            id="disable-frontend-multiprocessing")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_request_cancellation(server: RemoteOpenAIServer):
+    # clunky test: send an ungodly amount of load in with short timeouts
+    # then ensure that it still responds quickly afterwards
+
+    chat_input = [{"role": "user", "content": "Write a long story"}]
+    client = server.get_async_client(timeout=0.5)
+    tasks = []
+    # Request about 2 million tokens
+    for _ in range(200):
+        task = asyncio.create_task(
+            client.chat.completions.create(messages=chat_input,
+                                           model=MODEL_NAME,
+                                           max_tokens=10000,
+                                           extra_body={"min_tokens": 10000}))
+        tasks.append(task)
+
+    done, pending = await asyncio.wait(tasks,
+                                       return_when=asyncio.ALL_COMPLETED)
+
+    # Make sure all requests were sent to the server and timed out
+    # (We don't want to hide other errors like 400s that would invalidate this
+    # test)
+    assert len(pending) == 0
+    for d in done:
+        with pytest.raises(openai.APITimeoutError):
+            d.result()
+
+    # If the server had not cancelled all the other requests, then it would not
+    # be able to respond to this one within the timeout
+    client = server.get_async_client(timeout=5)
+    response = await client.chat.completions.create(messages=chat_input,
+                                                    model=MODEL_NAME,
+                                                    max_tokens=10)
+
+    assert len(response.choices) == 1
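
For orientation, here is a minimal standalone version of the client-side timeout mechanics the test leans on, assuming an OpenAI-compatible server at a hypothetical local URL and a placeholder model name; `openai.AsyncOpenAI` accepts a per-client `timeout` and surfaces expiry as `openai.APITimeoutError`:

    import asyncio

    import openai


    async def main() -> None:
        # Hypothetical endpoint and model name; substitute real values.
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="EMPTY",
                                    timeout=0.5,
                                    max_retries=0)
        try:
            await client.chat.completions.create(
                messages=[{"role": "user", "content": "Write a long story"}],
                model="my-model",
                max_tokens=10000)
        except openai.APITimeoutError:
            # The client gave up; with this fix the server notices the
            # disconnect and aborts the request instead of finishing it.
            print("request timed out client-side")


    asyncio.run(main())

Because each of the 200 requests asks for 10000 tokens (roughly 2 million tokens in total) but the client abandons them after 0.5 s, the follow-up request with a 5 s timeout only succeeds if the server really freed that work.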
@@ -1,7 +1,6 @@
 import asyncio
 import os
 import socket
-from functools import partial
 from typing import AsyncIterator, Tuple
 
 import pytest
@@ -26,10 +25,7 @@ async def test_merge_async_iterators():
             print(f"iterator {idx} cancelled")
 
     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator = merge_async_iterators(*iterators,
-                                            is_cancelled=partial(asyncio.sleep,
-                                                                 0,
-                                                                 result=False))
+    merged_iterator = merge_async_iterators(*iterators)
 
     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
@@ -163,12 +163,11 @@ class RemoteOpenAIServer:
             api_key=self.DUMMY_API_KEY,
         )
 
-    def get_async_client(self):
-        return openai.AsyncOpenAI(
-            base_url=self.url_for("v1"),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-        )
+    def get_async_client(self, **kwargs):
+        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
+                                  api_key=self.DUMMY_API_KEY,
+                                  max_retries=0,
+                                  **kwargs)
 
 
 def _test_completion(
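
A short usage sketch of the widened helper (inside a test that already has a `server: RemoteOpenAIServer` fixture): extra keyword arguments are forwarded verbatim to the `openai.AsyncOpenAI` constructor, which is how the new test installs its short per-client timeout.

    # As before, with library defaults:
    client = server.get_async_client()

    # Extra kwargs pass straight through to openai.AsyncOpenAI, e.g. a
    # 0.5 s request timeout for the cancellation test:
    impatient_client = server.get_async_client(timeout=0.5)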
@@ -1065,16 +1065,20 @@ class AsyncLLMEngine(EngineClient):
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-                request_id,
-                prompt,
-                sampling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                prompt_adapter_request=prompt_adapter_request,
-                priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, RequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    sampling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, RequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise
 
     async def encode(
         self,
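
The shape of that change, reduced to a self-contained sketch: when the consumer of an async generator is cancelled (for example because the HTTP handler awaiting it was cancelled on disconnect), `asyncio.CancelledError` is raised at the generator's current await point, giving it a chance to clean up before re-raising. The names below (`generate`, `consume`, the printed abort message) are illustrative stand-ins, not vLLM APIs:

    import asyncio
    from typing import AsyncGenerator


    async def generate(request_id: str) -> AsyncGenerator[str, None]:
        try:
            for i in range(1_000_000):
                await asyncio.sleep(0.01)  # stand-in for real engine work
                yield f"{request_id}: token {i}"
        except asyncio.CancelledError:
            # Mirrors the bugfix: tell the backend to stop working on this
            # request, then propagate the cancellation.
            print(f"aborting {request_id}")
            raise


    async def main() -> None:
        async def consume() -> None:
            async for _token in generate("req-0"):
                pass

        task = asyncio.create_task(consume())
        await asyncio.sleep(0.1)
        task.cancel()  # e.g. the client disconnected
        try:
            await task
        except asyncio.CancelledError:
            pass


    asyncio.run(main())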
@@ -1147,15 +1151,19 @@ class AsyncLLMEngine(EngineClient):
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-                request_id,
-                prompt,
-                pooling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    pooling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise
 
     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -17,11 +17,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
-                        random_uuid)
+from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger("vllm.entrypoints.api_server")
@@ -47,6 +47,11 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
+    return await _generate(request_dict, raw_request=request)
+
+
+@with_cancellation
+async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
@@ -54,8 +59,6 @@
 
     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
-    results_generator = iterate_with_cancellation(
-        results_generator, is_cancelled=request.is_disconnected)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
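
The pattern here is worth calling out: the public route reads the request body first and only then delegates to a decorated inner handler, so the decorator's assumption that the body has already been consumed holds. A minimal sketch of the same shape, with hypothetical names (`/echo`, `_echo`) rather than vLLM's:

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    from vllm.entrypoints.utils import with_cancellation

    app = FastAPI()


    @app.post("/echo")
    async def echo(request: Request) -> JSONResponse:
        # Consume the body *before* handing off to the cancellable handler.
        request_dict = await request.json()
        return await _echo(request_dict, raw_request=request)


    @with_cancellation
    async def _echo(request_dict: dict, raw_request: Request) -> JSONResponse:
        # Long-running work here is cancelled if the client disconnects.
        return JSONResponse(request_dict)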
@@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:
 
 
 @router.post("/tokenize")
+@with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 
 
 @router.post("/detokenize")
+@with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
@@ -353,6 +356,7 @@ async def show_version():
 
 
 @router.post("/v1/chat/completions")
+@with_cancellation
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 
 @router.post("/v1/completions")
+@with_cancellation
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 
 @router.post("/v1/embeddings")
+@with_cancellation
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 
 @router.post("/score")
+@with_cancellation
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
 
 
 @router.post("/v1/score")
+@with_cancellation
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
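
Note the ordering: `@with_cancellation` sits below `@router.post(...)`, so FastAPI registers the already-wrapped coroutine and every call goes through the disconnect watcher. A condensed sketch of the resulting shape (the route path and handler name here are illustrative, not part of the diff):

    from fastapi import APIRouter, Request

    from vllm.entrypoints.utils import with_cancellation

    router = APIRouter()


    @router.post("/v1/example")  # outermost: registers the wrapped handler
    @with_cancellation           # innermost: races the handler vs. disconnect
    async def create_example(request: dict, raw_request: Request):
        ...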
@@ -32,7 +32,6 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
-from vllm.utils import iterate_with_cancellation
 
 logger = init_logger(__name__)
 
@@ -234,10 +233,6 @@ class OpenAIServingChat(OpenAIServing):
         assert len(generators) == 1
         result_generator, = generators
 
-        if raw_request:
-            result_generator = iterate_with_cancellation(
-                result_generator, raw_request.is_disconnected)
-
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
@@ -159,8 +159,7 @@ class OpenAIServingCompletion(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        result_generator = merge_async_iterators(
-            *generators, is_cancelled=raw_request.is_disconnected)
+        result_generator = merge_async_iterators(*generators)
 
         model_name = self._get_model_name(lora_request)
         num_prompts = len(engine_prompts)
@@ -202,10 +202,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)
 
         num_prompts = len(engine_prompts)
 
@@ -186,10 +186,7 @@ class OpenAIServingScores(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)
 
         num_prompts = len(engine_prompts)
 
vllm/entrypoints/utils.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+import asyncio
+import functools
+
+from fastapi import Request
+
+
+async def listen_for_disconnect(request: Request) -> None:
+    """Returns if a disconnect message is received"""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            break
+
+
+def with_cancellation(handler_func):
+    """Decorator that allows a route handler to be cancelled by client
+    disconnections.
+
+    This does _not_ use request.is_disconnected, which does not work with
+    middleware. Instead this follows the pattern from
+    starlette.StreamingResponse, which simultaneously awaits on two tasks- one
+    to wait for an http disconnect message, and the other to do the work that we
+    want done. When the first task finishes, the other is cancelled.
+
+    A core assumption of this method is that the body of the request has already
+    been read. This is a safe assumption to make for fastapi handlers that have
+    already parsed the body of the request into a pydantic model for us.
+    This decorator is unsafe to use elsewhere, as it will consume and throw away
+    all incoming messages for the request while it looks for a disconnect
+    message.
+
+    In the case where a `StreamingResponse` is returned by the handler, this
+    wrapper will stop listening for disconnects and instead the response object
+    will start listening for disconnects.
+    """
+
+    # Functools.wraps is required for this wrapper to appear to fastapi as a
+    # normal route handler, with the correct request type hinting.
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+
+        # The request is either the second positional arg or `raw_request`
+        request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+        done, pending = await asyncio.wait([handler_task, cancellation_task],
+                                           return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper
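
The decorator's core is the two-task race described in its docstring. Below is a stripped-down, framework-free sketch of that race, using an `asyncio.Event` in place of the real `request.receive()` loop; every name in it is illustrative:

    import asyncio


    async def watch_disconnect(disconnected: asyncio.Event) -> None:
        # Stand-in for listen_for_disconnect(): returns once the client goes away.
        await disconnected.wait()


    async def handler() -> str:
        await asyncio.sleep(10)  # pretend this is expensive generation
        return "done"


    async def run_with_cancellation(disconnected: asyncio.Event):
        handler_task = asyncio.create_task(handler())
        cancel_task = asyncio.create_task(watch_disconnect(disconnected))

        done, pending = await asyncio.wait({handler_task, cancel_task},
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()

        return handler_task.result() if handler_task in done else None


    async def main() -> None:
        disconnected = asyncio.Event()
        run = asyncio.create_task(run_with_cancellation(disconnected))
        await asyncio.sleep(0.1)
        disconnected.set()   # simulate the http.disconnect message
        print(await run)     # None: the handler was cancelled


    asyncio.run(main())

Whichever task finishes first wins; the loser is cancelled, so a disconnect cancels the handler (and, via the AsyncLLMEngine change above, aborts the request) without any polling.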
@@ -20,7 +20,7 @@ import time
 import uuid
 import warnings
 import weakref
-from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
+from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
@@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None],
     return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]
 
 
-async def iterate_with_cancellation(
-    iterator: AsyncGenerator[T, None],
-    is_cancelled: Callable[[], Awaitable[bool]],
-) -> AsyncGenerator[T, None]:
-    """Convert async iterator into one that polls the provided function
-    at least once per second to check for client cancellation.
-    """
-
-    loop = asyncio.get_running_loop()
-
-    awaits: List[Future[T]] = [_next_task(iterator, loop)]
-    next_cancel_check: float = 0
-    while True:
-        done, pending = await asyncio.wait(awaits, timeout=1.5)
-
-        # Check for cancellation at most once per second
-        time_now = time.time()
-        if time_now >= next_cancel_check:
-            if await is_cancelled():
-                with contextlib.suppress(BaseException):
-                    awaits[0].cancel()
-                    await iterator.aclose()
-                raise asyncio.CancelledError("client cancelled")
-            next_cancel_check = time_now + 1
-
-        if done:
-            try:
-                item = await awaits[0]
-                awaits[0] = _next_task(iterator, loop)
-                yield item
-            except StopAsyncIteration:
-                # we are done
-                return
-
-
 async def merge_async_iterators(
-    *iterators: AsyncGenerator[T, None],
-    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
-) -> AsyncGenerator[Tuple[int, T], None]:
+        *iterators: AsyncGenerator[T,
+                                   None], ) -> AsyncGenerator[Tuple[int, T], None]:
     """Merge multiple asynchronous iterators into a single iterator.
 
     This method handle the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
-
-    It also optionally polls a provided function at least once per second
-    to check for client cancellation.
     """
 
     loop = asyncio.get_running_loop()
 
     awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
-    timeout = None if is_cancelled is None else 1.5
-    next_cancel_check: float = 0
     try:
         while awaits:
-            done, pending = await asyncio.wait(awaits.keys(),
-                                               return_when=FIRST_COMPLETED,
-                                               timeout=timeout)
-            if is_cancelled is not None:
-                # Check for cancellation at most once per second
-                time_now = time.time()
-                if time_now >= next_cancel_check:
-                    if await is_cancelled():
-                        raise asyncio.CancelledError("client cancelled")
-                    next_cancel_check = time_now + 1
+            done, _ = await asyncio.wait(awaits.keys(),
+                                         return_when=FIRST_COMPLETED)
             for d in done:
                 pair = awaits.pop(d)
                 try:
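
With the `is_cancelled` polling removed, cancellation reaches `merge_async_iterators` the same way it reaches any async generator: the consuming task is cancelled and the error propagates into the generator. A small usage sketch of the simplified signature, assuming the helper stays exported from `vllm.utils` as the imports elsewhere in this diff suggest (the input iterators are made up):

    import asyncio
    from typing import AsyncGenerator

    from vllm.utils import merge_async_iterators


    async def count(name: str, n: int) -> AsyncGenerator[str, None]:
        for i in range(n):
            await asyncio.sleep(0.01)
            yield f"{name}-{i}"


    async def main() -> None:
        merged = merge_async_iterators(count("a", 3), count("b", 2))
        async for idx, item in merged:
            # idx identifies which input iterator produced the item.
            print(idx, item)


    asyncio.run(main())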