mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server (#10635)
Signed-off-by: Tomer Asida <tomera@ai21.com>
This commit is contained in:
parent
9b4b150395
commit
395b1c7454
137
tests/entrypoints/openai/test_async_tokenization.py
Normal file
137
tests/entrypoints/openai/test_async_tokenization.py
Normal file
@ -0,0 +1,137 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import random
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(): # noqa: F811
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
ids=["completion", "chat"],
|
||||
argnames=["create_func_gen", "content_body"],
|
||||
argvalues=[
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": " ".join(['A'] * 10_000)
|
||||
}),
|
||||
(lambda x: x.chat.completions.create, {
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": " ".join(['A'] * 10_000)
|
||||
}]
|
||||
}),
|
||||
],
|
||||
)
|
||||
async def test_with_and_without_truncate(
|
||||
server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI,
|
||||
create_func_gen: Callable,
|
||||
content_body: dict,
|
||||
):
|
||||
create_func = create_func_gen(client)
|
||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||
|
||||
num_requests = 10
|
||||
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
|
||||
(num_requests - num_requests // 2))
|
||||
random.shuffle(truncate_prompt_tokens)
|
||||
|
||||
bodies = [{
|
||||
**body, "extra_body": {
|
||||
'truncate_prompt_tokens': t
|
||||
}
|
||||
} for t in truncate_prompt_tokens]
|
||||
|
||||
async def get_status_code(**kwargs):
|
||||
try:
|
||||
await create_func(**kwargs)
|
||||
return 200
|
||||
except openai.APIStatusError as e:
|
||||
return e.status_code
|
||||
|
||||
responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
|
||||
assert 500 not in responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
ids=["single completion", "multiple completions", "chat"],
|
||||
argnames=["create_func_gen", "content_body"],
|
||||
argvalues=[
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": " ".join(['A'] * 300_000)
|
||||
}),
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": [" ".join(['A'] * 300_000)] * 2
|
||||
}),
|
||||
(lambda x: x.chat.completions.create, {
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": " ".join(['A'] * 300_000)
|
||||
}]
|
||||
}),
|
||||
],
|
||||
)
|
||||
async def test_healthcheck_response_time(
|
||||
server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI,
|
||||
create_func_gen: Callable,
|
||||
content_body: dict,
|
||||
):
|
||||
num_requests = 50
|
||||
|
||||
create_func = create_func_gen(client)
|
||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||
|
||||
def get_response_time(url):
|
||||
start_time = time.monotonic()
|
||||
res = requests.get(url)
|
||||
end_time = time.monotonic()
|
||||
assert res.status_code == 200
|
||||
return end_time - start_time
|
||||
|
||||
no_load_response_time = get_response_time(server.url_for("health"))
|
||||
tasks = [
|
||||
asyncio.create_task(create_func(**body)) for _ in range(num_requests)
|
||||
]
|
||||
await asyncio.sleep(1) # give the tasks a chance to start running
|
||||
load_response_time = get_response_time(server.url_for("health"))
|
||||
|
||||
with contextlib.suppress(openai.APIStatusError):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
assert load_response_time < 100 * no_load_response_time
|
||||
assert load_response_time < 0.1
|
||||
@ -101,7 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request_prompts, engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
|
||||
@ -156,13 +156,14 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import pathlib
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
|
||||
@ -46,7 +47,7 @@ from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import AtomicCounter, is_list_of
|
||||
from vllm.utils import AtomicCounter, is_list_of, make_async
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -140,6 +141,14 @@ class OpenAIServing:
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._tokenize_prompt_input_async = make_async(
|
||||
self._tokenize_prompt_input, executor=self._tokenizer_executor)
|
||||
self._tokenize_prompt_input_or_inputs_async = make_async(
|
||||
self._tokenize_prompt_input_or_inputs,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
@ -368,7 +377,7 @@ class OpenAIServing:
|
||||
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Iterator[TextTokensPrompt]:
|
||||
) -> List[TextTokensPrompt]:
|
||||
"""
|
||||
Tokenize/detokenize depending on the input format.
|
||||
|
||||
@ -376,45 +385,41 @@ class OpenAIServing:
|
||||
, each input can be a string or array of tokens. Note that each request
|
||||
can pass one or more inputs.
|
||||
"""
|
||||
for prompt_input in parse_and_batch_prompt(input_or_inputs):
|
||||
# Although our type checking is based on mypy,
|
||||
# VSCode Pyright extension should still work properly
|
||||
# "is True" is required for Pyright to perform type narrowing
|
||||
# See: https://github.com/microsoft/pyright/issues/7672
|
||||
if prompt_input["is_tokens"] is False:
|
||||
yield self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
# Although our type checking is based on mypy,
|
||||
# VSCode Pyright extension should still work properly
|
||||
# "is True" is required for Pyright to perform type narrowing
|
||||
# See: https://github.com/microsoft/pyright/issues/7672
|
||||
return [
|
||||
self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens)
|
||||
if prompt_input["is_tokens"] is False else
|
||||
self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens)
|
||||
for prompt_input in parse_and_batch_prompt(input_or_inputs)
|
||||
]
|
||||
|
||||
def _preprocess_completion(
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
|
||||
request_prompts = [
|
||||
request_prompt
|
||||
for request_prompt in self._tokenize_prompt_input_or_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
input_or_inputs,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
]
|
||||
) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
|
||||
request_prompts = await self._tokenize_prompt_input_or_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
input_or_inputs,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
engine_prompts = [
|
||||
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
|
||||
@ -493,7 +498,7 @@ class OpenAIServing:
|
||||
request=request)
|
||||
|
||||
if isinstance(request_prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
prompt_inputs = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
|
||||
@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
from vllm.utils import make_async, merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -145,9 +145,11 @@ class OpenAIServingScores(OpenAIServing):
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
prompt_inputs = tokenizer(text=q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(text=q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
@ -81,12 +81,13 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
@ -134,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
# (Unlike in Embeddings API where an error is raised)
|
||||
|
||||
prompt_input = self._tokenize_prompt_input(
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent
|
||||
import contextlib
|
||||
import datetime
|
||||
import enum
|
||||
@ -351,7 +352,10 @@ def in_wsl() -> bool:
|
||||
return "microsoft" in " ".join(uname()).lower()
|
||||
|
||||
|
||||
def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
|
||||
def make_async(
|
||||
func: Callable[P, T],
|
||||
executor: Optional[concurrent.futures.Executor] = None
|
||||
) -> Callable[P, Awaitable[T]]:
|
||||
"""Take a blocking function, and run it on in an executor thread.
|
||||
|
||||
This function prevents the blocking function from blocking the
|
||||
@ -362,7 +366,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
|
||||
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
|
||||
loop = asyncio.get_event_loop()
|
||||
p_func = partial(func, *args, **kwargs)
|
||||
return loop.run_in_executor(executor=None, func=p_func)
|
||||
return loop.run_in_executor(executor=executor, func=p_func)
|
||||
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user