[Front-end] microbatch tokenization (#19334)

Signed-off-by: zt2370 <ztang2370@gmail.com>
ztang2370 2025-07-08 00:54:10 +08:00 committed by GitHub
parent edd270bc78
commit a37d75bbec
3 changed files with 288 additions and 64 deletions

View File

@ -7,6 +7,8 @@ from dataclasses import dataclass, field
from typing import Any, Optional
from unittest.mock import MagicMock
import pytest
from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@ -73,7 +75,8 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
req.max_tokens = 10
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 15
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 100
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
def test_serving_chat_could_load_correct_generation_config():
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.1
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.0
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt():
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig()
mock_engine = MagicMock(spec=MQLLMEngineClient)
@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import io
import json
import sys
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
Sequence)
from concurrent.futures.thread import ThreadPoolExecutor
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypeVar, Union, cast, overload)
@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
random_uuid)
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
merge_async_iterators, random_uuid)
logger = init_logger(__name__)
@ -226,11 +226,19 @@ class OpenAIServing:
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._tokenize_prompt_input_async = make_async(
self._tokenize_prompt_input, executor=self._tokenizer_executor)
self._tokenize_prompt_input_or_inputs_async = make_async(
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)
self._async_tokenizer_pool: dict[AnyTokenizer,
AsyncMicrobatchTokenizer] = {}
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess(
self,
@ -467,7 +475,7 @@ class OpenAIServing:
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _normalize_prompt_text_to_input(
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
@ -475,38 +483,44 @@ class OpenAIServing:
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
prompt = prompt.lower()
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len)
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len)
else:
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: list[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
@ -514,7 +528,7 @@ class OpenAIServing:
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
@ -578,7 +592,7 @@ class OpenAIServing:
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _tokenize_prompt_input(
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
@ -591,23 +605,24 @@ class OpenAIServing:
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
):
return result
raise ValueError("No results yielded from tokenization")
def _tokenize_prompt_inputs(
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
) -> AsyncGenerator[TextTokensPrompt, None]:
"""
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
@ -615,7 +630,7 @@ class OpenAIServing:
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
yield await self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
@ -623,14 +638,14 @@ class OpenAIServing:
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
yield await self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
async def _tokenize_prompt_input_or_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
@ -664,21 +679,31 @@ class OpenAIServing:
# VSCode Pyright extension should still work properly
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
inputs_text.extend([
self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens)
if prompt_input["is_tokens"] is False else
self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens)
for prompt_input in parse_and_batch_prompt(input_or_inputs)
])
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(input_or_inputs)
# Process each input in the batch concurrently
tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is False:
task = self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens)
else:
task = self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens)
tasks.append(task)
# Wait for all tokenization tasks to complete
results = await asyncio.gather(*tasks)
inputs_text.extend(results)
return inputs_text, inputs_embeds

View File

@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
Hashable, Iterable, Iterator, KeysView, Mapping,
Sequence)
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
@ -64,6 +65,7 @@ import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs
@ -507,6 +509,196 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
class AsyncMicrobatchTokenizer:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def __init__(
self,
tokenizer,
max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002,
) -> None:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self._loop = asyncio.get_running_loop()
self._queues: dict[tuple,
asyncio.Queue[Union[tuple[str, dict,
asyncio.Future],
tuple[list[int],
asyncio.Future]]]] = {}
self._batcher_tasks: list[asyncio.Task] = []
# Single-thread executor for blocking tokenizer calls.
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
result_future: asyncio.Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
result_future: asyncio.Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((token_ids, result_future))
return await result_future
# === Internal helpers ===
def _get_queue(
self, loop: asyncio.AbstractEventLoop, key: tuple
) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[
list[int], asyncio.Future]]]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue = self._queues.get(key)
if queue is None:
self._queues[key] = queue = asyncio.Queue()
if key[0] == "encode":
can_batch = key[1] != "other"
coro = self._batch_encode_loop(queue, can_batch)
else:
assert key[0] == "decode", \
f"Unknown operation type: {key[0]}."
coro = self._batch_decode_loop(queue)
self._batcher_tasks.append(loop.create_task(coro))
return queue
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
"""Batch incoming encode requests for efficiency."""
while True:
prompt, kwargs, result_future = await queue.get()
prompts = [prompt]
kwargs_list = [kwargs]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(prompts) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
prompt, kwargs, result_future = await asyncio.wait_for(
queue.get(), timeout)
prompts.append(prompt)
result_futures.append(result_future)
if not can_batch:
kwargs_list.append(kwargs)
except asyncio.TimeoutError:
break
try:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
encode_fn = partial(self.tokenizer, prompts, **kwargs)
results = await self._loop.run_in_executor(
self._executor, encode_fn)
for i, fut in enumerate(result_futures):
if not fut.done():
data = {k: v[i] for k, v in results.items()}
fut.set_result(BatchEncoding(data))
else:
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
self.tokenizer(p, **kw)
for p, kw in zip(prompts, kwargs)
]
results = await self._loop.run_in_executor(
self._executor, encode_fn)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
async def _batch_decode_loop(self, queue: asyncio.Queue):
"""Batch incoming decode requests for efficiency."""
while True:
token_ids, result_future = await queue.get()
token_ids_list = [token_ids]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(token_ids_list) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
token_ids, result_future = await asyncio.wait_for(
queue.get(), timeout)
token_ids_list.append(token_ids)
result_futures.append(result_future)
except asyncio.TimeoutError:
break
try:
# Perform a single batched decode call for all requests
results = await self._loop.run_in_executor(
self._executor, self.tokenizer.batch_decode,
token_ids_list)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
def _queue_key(self, op: str, kwargs: dict) -> tuple:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False, `max_length` is ignored and this
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if op == "decode":
return ("decode", )
add_special_tokens = kwargs.get("add_special_tokens", True)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
if not truncation:
return ("encode", add_special_tokens, False, None)
model_max = getattr(self.tokenizer, "model_max_length", None)
if max_length is None or (model_max is not None
and max_length == model_max):
return ("encode", add_special_tokens, True, "model_max")
return ("encode", "other")
def __del__(self):
for task in self._batcher_tasks:
if not task.done():
task.cancel()
def make_async(
func: Callable[P, T],
executor: Optional[concurrent.futures.Executor] = None