diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index ad80946b5671..8a7892cf6d6a 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
 from typing import Any, Optional
 from unittest.mock import MagicMock
 
+import pytest
+
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE
 
 
-def test_serving_chat_should_set_correct_max_tokens():
+@pytest.mark.asyncio
+async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
                                      request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
     req.max_tokens = 10
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 15
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 100
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
 
-def test_serving_chat_could_load_correct_generation_config():
+@pytest.mark.asyncio
+async def test_serving_chat_could_load_correct_generation_config():
 
     mock_model_config = MockModelConfig()
     mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
                                      request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.5
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.1
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.1
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.0
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.0
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
 
 
-def test_serving_chat_did_set_correct_cache_salt():
+@pytest.mark.asyncio
+async def test_serving_chat_did_set_correct_cache_salt():
     mock_model_config = MockModelConfig()
 
     mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
 
     # By default cache_salt in the engine prompt is not set
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert "cache_salt" not in mock_engine.generate.call_args.args[0]
 
     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index c4ebb7141d09..bec2e1254795 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
 import base64
 import io
 import json
 import sys
 import time
-from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
-                             Sequence)
-from concurrent.futures.thread import ThreadPoolExecutor
+from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
 from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                     TypeVar, Union, cast, overload)
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import (is_list_of, make_async, merge_async_iterators,
-                        random_uuid)
+from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
+                        merge_async_iterators, random_uuid)
 
 logger = init_logger(__name__)
 
@@ -226,11 +226,19 @@ class OpenAIServing:
 
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
 
-        self._tokenize_prompt_input_async = make_async(
-            self._tokenize_prompt_input, executor=self._tokenizer_executor)
-        self._tokenize_prompt_input_or_inputs_async = make_async(
-            self._tokenize_prompt_input_or_inputs,
-            executor=self._tokenizer_executor)
+        self._async_tokenizer_pool: dict[AnyTokenizer,
+                                         AsyncMicrobatchTokenizer] = {}
+
+    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
+        """
+        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
+        given tokenizer.
+        """
+        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
+        if async_tokenizer is None:
+            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
+            self._async_tokenizer_pool[tokenizer] = async_tokenizer
+        return async_tokenizer
 
     async def _preprocess(
         self,
@@ -467,7 +475,7 @@ class OpenAIServing:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
 
-    def _normalize_prompt_text_to_input(
+    async def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -475,38 +483,44 @@ class OpenAIServing:
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
                     "do_lower_case", False)):
             prompt = prompt.lower()
 
         if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+            encoded = await async_tokenizer(
+                prompt, add_special_tokens=add_special_tokens)
         elif truncate_prompt_tokens < 0:
             # Negative means we cap at the model's max length
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=self.max_model_len)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=self.max_model_len)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=truncate_prompt_tokens)
 
         input_ids = encoded.input_ids
-
         input_text = prompt
 
         return self._validate_input(request, input_ids, input_text)
 
-    def _normalize_prompt_tokens_to_input(
+    async def _normalize_prompt_tokens_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_ids: list[int],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -514,7 +528,7 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
 
-        input_text = tokenizer.decode(input_ids)
+        input_text = await async_tokenizer.decode(input_ids)
 
         return self._validate_input(request, input_ids, input_text)
 
@@ -578,7 +592,7 @@ class OpenAIServing:
         return TextTokensPrompt(prompt=input_text,
                                 prompt_token_ids=input_ids)
 
-    def _tokenize_prompt_input(
+    async def _tokenize_prompt_input_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -591,23 +605,24 @@ class OpenAIServing:
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
         that assumes single input.
         """
-        return next(
-            self._tokenize_prompt_inputs(
+        async for result in self._tokenize_prompt_inputs_async(
                 request,
                 tokenizer,
-                [prompt_input],
+            [prompt_input],
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
-            ))
+        ):
+            return result
+        raise ValueError("No results yielded from tokenization")
 
-    def _tokenize_prompt_inputs(
+    async def _tokenize_prompt_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Iterator[TextTokensPrompt]:
+    ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
         A simpler implementation of
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
@@ -615,7 +630,7 @@ class OpenAIServing:
         """
         for text in prompt_inputs:
             if isinstance(text, str):
-                yield self._normalize_prompt_text_to_input(
+                yield await self._normalize_prompt_text_to_input(
                     request,
                     tokenizer,
                     prompt=text,
@@ -623,14 +638,14 @@ class OpenAIServing:
                     add_special_tokens=add_special_tokens,
                 )
             else:
-                yield self._normalize_prompt_tokens_to_input(
+                yield await self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
                     prompt_ids=text,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                 )
 
-    def _tokenize_prompt_input_or_inputs(
+    async def _tokenize_prompt_input_or_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -664,21 +679,31 @@ class OpenAIServing:
         # VSCode Pyright extension should still work properly
         # "is False" is required for Pyright to perform type narrowing
         # See: https://github.com/microsoft/pyright/issues/7672
-        inputs_text.extend([
-            self._normalize_prompt_text_to_input(
-                request,
-                tokenizer,
-                prompt=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens,
-                add_special_tokens=add_special_tokens)
-            if prompt_input["is_tokens"] is False else
-            self._normalize_prompt_tokens_to_input(
-                request,
-                tokenizer,
-                prompt_ids=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens)
-            for prompt_input in parse_and_batch_prompt(input_or_inputs)
-        ])
+
+        # Parse and batch the input prompts
+        batch_inputs = parse_and_batch_prompt(input_or_inputs)
+
+        # Process each input in the batch concurrently
+        tasks = []
+        for prompt_input in batch_inputs:
+            if prompt_input["is_tokens"] is False:
+                task = self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens)
+            else:
+                task = self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens)
+            tasks.append(task)
+
+        # Wait for all tokenization tasks to complete
+        results = await asyncio.gather(*tasks)
+        inputs_text.extend(results)
 
         return inputs_text, inputs_embeds
 
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 9322e3cc477a..bfdbd682464a 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
                              Hashable, Iterable, Iterator, KeysView, Mapping,
                              Sequence)
+from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
@@ -64,6 +65,7 @@ import zmq.asyncio
 from packaging import version
 from packaging.version import Version
 from torch.library import Library
+from transformers.tokenization_utils_base import BatchEncoding
 from typing_extensions import Never, ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
@@ -507,6 +509,196 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+class AsyncMicrobatchTokenizer:
+    """Asynchronous tokenizer with micro-batching.
+
+    Pulls pending encode/decode requests from a queue and batches them
+    up to reduce overhead. A single-thread ThreadPoolExecutor is used
+    so the event loop stays responsive.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_batch_size: int = 32,
+        batch_wait_timeout_s: float = 0.002,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+
+        self._loop = asyncio.get_running_loop()
+        self._queues: dict[tuple,
+                           asyncio.Queue[Union[tuple[str, dict,
+                                                     asyncio.Future],
+                                               tuple[list[int],
+                                                     asyncio.Future]]]] = {}
+        self._batcher_tasks: list[asyncio.Task] = []
+
+        # Single-thread executor for blocking tokenizer calls.
+        self._executor = ThreadPoolExecutor(max_workers=1)
+
+    # === Public async API ===
+    async def __call__(self, prompt, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("encode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((prompt, kwargs, result_future))
+        return await result_future
+
+    async def decode(self, token_ids, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("decode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((token_ids, result_future))
+        return await result_future
+
+    # === Internal helpers ===
+    def _get_queue(
+        self, loop: asyncio.AbstractEventLoop, key: tuple
+    ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[
+            list[int], asyncio.Future]]]:
+        """Get the request queue for the given operation key, creating a new
+        queue and batcher task if needed."""
+        queue = self._queues.get(key)
+        if queue is None:
+            self._queues[key] = queue = asyncio.Queue()
+            if key[0] == "encode":
+                can_batch = key[1] != "other"
+                coro = self._batch_encode_loop(queue, can_batch)
+            else:
+                assert key[0] == "decode", \
+                    f"Unknown operation type: {key[0]}."
+                coro = self._batch_decode_loop(queue)
+            self._batcher_tasks.append(loop.create_task(coro))
+        return queue
+
+    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
+        """Batch incoming encode requests for efficiency."""
+        while True:
+            prompt, kwargs, result_future = await queue.get()
+            prompts = [prompt]
+            kwargs_list = [kwargs]
+            result_futures = [result_future]
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+
+            while len(prompts) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    prompt, kwargs, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    prompts.append(prompt)
+                    result_futures.append(result_future)
+                    if not can_batch:
+                        kwargs_list.append(kwargs)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # If every request uses identical kwargs we can run a single
+                # batched tokenizer call for a big speed-up.
+                if can_batch and len(prompts) > 1:
+                    encode_fn = partial(self.tokenizer, prompts, **kwargs)
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for i, fut in enumerate(result_futures):
+                        if not fut.done():
+                            data = {k: v[i] for k, v in results.items()}
+                            fut.set_result(BatchEncoding(data))
+                else:
+                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
+                        self.tokenizer(p, **kw)
+                        for p, kw in zip(prompts, kwargs)
+                    ]
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for fut, res in zip(result_futures, results):
+                        if not fut.done():
+                            fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    async def _batch_decode_loop(self, queue: asyncio.Queue):
+        """Batch incoming decode requests for efficiency."""
+        while True:
+            token_ids, result_future = await queue.get()
+            token_ids_list = [token_ids]
+            result_futures = [result_future]
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+
+            while len(token_ids_list) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    token_ids, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    token_ids_list.append(token_ids)
+                    result_futures.append(result_future)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # Perform a single batched decode call for all requests
+                results = await self._loop.run_in_executor(
+                    self._executor, self.tokenizer.batch_decode,
+                    token_ids_list)
+                for fut, res in zip(result_futures, results):
+                    if not fut.done():
+                        fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    def _queue_key(self, op: str, kwargs: dict) -> tuple:
+        """
+        Return a normalized key describing operation + kwargs.
+
+        - `add_special_tokens`: {True/False}
+        - `truncation`: {True/False}
+          - If `truncation` is False (`max_length` is None),
+            returns a key for a can_batch queue.
+          - If `truncation` is True and `max_length` is None or equals
+            `tokenizer.model_max_length`, returns a key for a can_batch queue.
+          - Otherwise, returns a key for a cannot_batch queue.
+
+        Examples:
+        - Decode: ("decode",)
+        - Encode typical:
+          ("encode", add_special_tokens, bool_truncation, max_length_label)
+        - Fallback: ("encode", "other")
+        """
+
+        if op == "decode":
+            return ("decode", )
+
+        add_special_tokens = kwargs.get("add_special_tokens", True)
+        truncation = kwargs.get("truncation", False)
+        max_length = kwargs.get("max_length")
+
+        if not truncation:
+            return ("encode", add_special_tokens, False, None)
+
+        model_max = getattr(self.tokenizer, "model_max_length", None)
+        if max_length is None or (model_max is not None
+                                  and max_length == model_max):
+            return ("encode", add_special_tokens, True, "model_max")
+
+        return ("encode", "other")
+
+    def __del__(self):
+        for task in self._batcher_tasks:
+            if not task.done():
+                task.cancel()
+
+
 def make_async(
         func: Callable[P, T],
         executor: Optional[concurrent.futures.Executor] = None