[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server (#10635)

Signed-off-by: Tomer Asida <tomera@ai21.com>
tomeras91 2024-11-27 23:21:10 +02:00 committed by GitHub
parent 9b4b150395
commit 395b1c7454
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 206 additions and 56 deletions
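The problem this change addresses: the OpenAI-compatible frontend tokenizes prompts synchronously inside its async request handlers, so a single request with a very long prompt can stall the server's asyncio event loop and every other in-flight request, including /health probes. The fix routes preprocessing through a wrapper that runs the tokenizer on a dedicated worker thread. A minimal, self-contained sketch of the mechanism (tokenize_blocking and tokenize_offloaded are invented names for illustration, not vLLM code):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# Single worker thread reserved for tokenization, mirroring the
# ThreadPoolExecutor(max_workers=1) added to OpenAIServing in this commit.
_tokenizer_executor = ThreadPoolExecutor(max_workers=1)


def tokenize_blocking(text: str) -> int:
    # Stand-in for a CPU-bound tokenizer call on a very long prompt.
    return len(text.split())


async def tokenize_offloaded(text: str) -> int:
    # The event loop only schedules the work and awaits the result;
    # the actual CPU time is spent on the executor thread.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_tokenizer_executor,
                                      partial(tokenize_blocking, text))


async def main() -> None:
    num_tokens = await tokenize_offloaded("A " * 10_000)
    print(num_tokens)


if __name__ == "__main__":
    asyncio.run(main())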

View File

@@ -0,0 +1,137 @@
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        "--load-format",
        "dummy",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["completion", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 10_000)
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 10_000)
            }]
        }),
    ],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    num_requests = 10
    truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
                              (num_requests - num_requests // 2))
    random.shuffle(truncate_prompt_tokens)

    bodies = [{
        **body, "extra_body": {
            'truncate_prompt_tokens': t
        }
    } for t in truncate_prompt_tokens]

    async def get_status_code(**kwargs):
        try:
            await create_func(**kwargs)
            return 200
        except openai.APIStatusError as e:
            return e.status_code

    responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
    assert 500 not in responses


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["single completion", "multiple completions", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 300_000)
        }),
        (lambda x: x.completions.create, {
            "prompt": [" ".join(['A'] * 300_000)] * 2
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 300_000)
            }]
        }),
    ],
)
async def test_healthcheck_response_time(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    num_requests = 50

    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    def get_response_time(url):
        start_time = time.monotonic()
        res = requests.get(url)
        end_time = time.monotonic()
        assert res.status_code == 200
        return end_time - start_time

    no_load_response_time = get_response_time(server.url_for("health"))
    tasks = [
        asyncio.create_task(create_func(**body)) for _ in range(num_requests)
    ]
    await asyncio.sleep(1)  # give the tasks a chance to start running
    load_response_time = get_response_time(server.url_for("health"))

    with contextlib.suppress(openai.APIStatusError):
        await asyncio.gather(*tasks)

    assert load_response_time < 100 * no_load_response_time
    assert load_response_time < 0.1

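The first test exercises truncate_prompt_tokens through the new async preprocessing path and only asserts that no request comes back as a 500. The second test is the real guard for this change: it fires 50 requests carrying ~300k-character prompts and requires that GET /health still answers within 100x the idle response time and in under 0.1s, which would fail if tokenization ran on the event loop. The standalone sketch below (not part of the PR; busy() is an invented stand-in for tokenization) shows the same effect in miniature by measuring how late a 10 ms sleep wakes up while a CPU-bound job runs inline versus on an executor:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def busy(seconds: float) -> None:
    # Burn CPU for `seconds`, standing in for tokenizing a huge prompt.
    end = time.monotonic() + seconds
    while time.monotonic() < end:
        pass


async def lag_while(job) -> float:
    # Schedule the job, then check how late a 10 ms sleep actually wakes up.
    task = asyncio.ensure_future(job)
    start = time.monotonic()
    await asyncio.sleep(0.01)
    lag = time.monotonic() - start - 0.01
    await task
    return lag


async def run_blocking() -> None:
    busy(0.5)  # runs on the event loop thread and stalls it


async def run_offloaded(executor: ThreadPoolExecutor) -> None:
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, busy, 0.5)  # loop stays responsive


async def main() -> None:
    executor = ThreadPoolExecutor(max_workers=1)
    blocking_lag = await lag_while(run_blocking())
    offloaded_lag = await lag_while(run_offloaded(executor))
    print(f"event-loop lag, blocking job:  {blocking_lag:.3f}s")
    print(f"event-loop lag, offloaded job: {offloaded_lag:.3f}s")


if __name__ == "__main__":
    asyncio.run(main())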
View File

@@ -101,7 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
-            request_prompts, engine_prompts = self._preprocess_completion(
+            request_prompts, engine_prompts = await self._preprocess_completion(
                 request,
                 tokenizer,
                 request.prompt,

View File

@@ -156,7 +156,8 @@ class OpenAIServingEmbedding(OpenAIServing):
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
+                (request_prompts,
+                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,

View File

@@ -1,5 +1,6 @@
 import json
 import pathlib
+from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
@@ -46,7 +47,7 @@ from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import AtomicCounter, is_list_of
+from vllm.utils import AtomicCounter, is_list_of, make_async
 
 
 logger = init_logger(__name__)
@@ -140,6 +141,14 @@ class OpenAIServing:
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
+        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
+
+        self._tokenize_prompt_input_async = make_async(
+            self._tokenize_prompt_input, executor=self._tokenizer_executor)
+        self._tokenize_prompt_input_or_inputs_async = make_async(
+            self._tokenize_prompt_input_or_inputs,
+            executor=self._tokenizer_executor)
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -368,7 +377,7 @@ class OpenAIServing:
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Iterator[TextTokensPrompt]:
+    ) -> List[TextTokensPrompt]:
         """
         Tokenize/detokenize depending on the input format.
 
@@ -376,45 +385,41 @@
         , each input can be a string or array of tokens. Note that each request
         can pass one or more inputs.
         """
-        for prompt_input in parse_and_batch_prompt(input_or_inputs):
-            # Although our type checking is based on mypy,
-            # VSCode Pyright extension should still work properly
-            # "is True" is required for Pyright to perform type narrowing
-            # See: https://github.com/microsoft/pyright/issues/7672
-            if prompt_input["is_tokens"] is False:
-                yield self._normalize_prompt_text_to_input(
-                    request,
-                    tokenizer,
-                    prompt=prompt_input["content"],
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=add_special_tokens,
-                )
-            else:
-                yield self._normalize_prompt_tokens_to_input(
-                    request,
-                    tokenizer,
-                    prompt_ids=prompt_input["content"],
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                )
+        # Although our type checking is based on mypy,
+        # VSCode Pyright extension should still work properly
+        # "is True" is required for Pyright to perform type narrowing
+        # See: https://github.com/microsoft/pyright/issues/7672
+        return [
+            self._normalize_prompt_text_to_input(
+                request,
+                tokenizer,
+                prompt=prompt_input["content"],
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens)
+            if prompt_input["is_tokens"] is False else
+            self._normalize_prompt_tokens_to_input(
+                request,
+                tokenizer,
                prompt_ids=prompt_input["content"],
+                truncate_prompt_tokens=truncate_prompt_tokens)
+            for prompt_input in parse_and_batch_prompt(input_or_inputs)
+        ]
 
-    def _preprocess_completion(
+    async def _preprocess_completion(
         self,
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
-        request_prompts = [
-            request_prompt
-            for request_prompt in self._tokenize_prompt_input_or_inputs(
-                request,
-                tokenizer,
-                input_or_inputs,
-                truncate_prompt_tokens=truncate_prompt_tokens,
-                add_special_tokens=add_special_tokens,
-            )
-        ]
+    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
+        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
+            request,
+            tokenizer,
+            input_or_inputs,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            add_special_tokens=add_special_tokens,
+        )
 
         engine_prompts = [
             TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
@@ -493,7 +498,7 @@ class OpenAIServing:
                 request=request)
 
         if isinstance(request_prompt, str):
-            prompt_inputs = self._tokenize_prompt_input(
+            prompt_inputs = await self._tokenize_prompt_input_async(
                 request,
                 tokenizer,
                 request_prompt,

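Two details of the wiring above are worth noting. The bound tokenize helpers are wrapped once in __init__, so the per-request cost is just a submit to the executor, and callers simply await the _async variants. The executor also has a single worker, so concurrent requests tokenize one at a time in the background rather than sharing one tokenizer object across threads; the PR does not state the rationale, so treat that reading as an inference. A hedged sketch of the resulting behaviour, with invented names:

import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# One worker: submissions are queued and run strictly one after another.
executor = ThreadPoolExecutor(max_workers=1)


def fake_tokenize(req_id: int) -> str:
    time.sleep(0.2)  # stand-in for CPU-heavy tokenization
    return f"request {req_id} tokenized on {threading.current_thread().name}"


async def handle(req_id: int) -> str:
    loop = asyncio.get_running_loop()
    # The handler yields here, so the event loop keeps serving other requests.
    return await loop.run_in_executor(executor, fake_tokenize, req_id)


async def main() -> None:
    results = await asyncio.gather(*(handle(i) for i in range(3)))
    print("\n".join(results))  # all three report the same worker thread


if __name__ == "__main__":
    asyncio.run(main())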
View File

@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import make_async, merge_async_iterators, random_uuid
 
 
 logger = init_logger(__name__)
@@ -145,7 +145,9 @@ class OpenAIServingScores(OpenAIServing):
                 tokenization_kwargs["truncation"] = True
                 tokenization_kwargs["max_length"] = truncate_prompt_tokens
 
-            prompt_inputs = tokenizer(text=q,
-                                      text_pair=t,
-                                      **tokenization_kwargs)
+            tokenize_async = make_async(tokenizer.__call__,
+                                        executor=self._tokenizer_executor)
+            prompt_inputs = await tokenize_async(text=q,
+                                                 text_pair=t,
+                                                 **tokenization_kwargs)
             engine_prompt = TokensPrompt(

View File

@@ -81,7 +81,8 @@ class OpenAIServingTokenization(OpenAIServing):
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
+                (request_prompts,
+                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.prompt,
@@ -134,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
 
         # Silently ignore prompt adapter since it does not affect tokenization
         # (Unlike in Embeddings API where an error is raised)
-        prompt_input = self._tokenize_prompt_input(
+        prompt_input = await self._tokenize_prompt_input_async(
            request,
             tokenizer,
             request.tokens,

View File

@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import concurrent
 import contextlib
 import datetime
 import enum
@@ -351,7 +352,10 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
-def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
+def make_async(
+    func: Callable[P, T],
+    executor: Optional[concurrent.futures.Executor] = None
+) -> Callable[P, Awaitable[T]]:
     """Take a blocking function, and run it on in an executor thread.
 
     This function prevents the blocking function from blocking the
@@ -362,7 +366,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
         loop = asyncio.get_event_loop()
         p_func = partial(func, *args, **kwargs)
-        return loop.run_in_executor(executor=None, func=p_func)
+        return loop.run_in_executor(executor=executor, func=p_func)
 
     return _async_wrapper
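With the new executor parameter, make_async keeps its old behaviour when called with a single argument (executor=None falls through to the event loop's default executor), so existing call sites are untouched, while the serving classes can pin tokenization to their dedicated single-worker pool. A small usage sketch; tokenize_fn is an invented stand-in for any blocking callable, such as a tokenizer's __call__:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List

from vllm.utils import make_async


def tokenize_fn(text: str) -> List[str]:
    # Pretend this is an expensive, CPU-bound tokenizer call.
    return text.split()


executor = ThreadPoolExecutor(max_workers=1)
tokenize_async = make_async(tokenize_fn, executor=executor)
# make_async(tokenize_fn) is still valid and uses the loop's default executor.


async def main() -> None:
    tokens = await tokenize_async("A " * 1_000)
    print(len(tokens))


if __name__ == "__main__":
    asyncio.run(main())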