[Front-end] microbatch tokenization (#19334)

Signed-off-by: zt2370 <ztang2370@gmail.com>
Authored by ztang2370 on 2025-07-08 00:54:10 +08:00, committed by GitHub
parent edd270bc78
commit a37d75bbec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 288 additions and 64 deletions
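
Note: the snippet below is an illustrative sketch, not part of the diff. It shows the general micro-batching pattern this commit applies to tokenization: requests are queued, a background task drains the queue for a short window (2 ms here, matching the default batch_wait_timeout_s introduced below), and one batched call serves the whole group. All names in the sketch (MicroBatcher, batch_fn, submit) are hypothetical; the real implementation is the AsyncMicrobatchTokenizer class added to vllm.utils in the third file of this diff.

import asyncio


class MicroBatcher:
    """Collect submitted items for a short window, then process them together."""

    def __init__(self, batch_fn, max_batch_size=32, wait_s=0.002):
        self.batch_fn = batch_fn              # callable that handles a list of items at once
        self.max_batch_size = max_batch_size
        self.wait_s = wait_s
        self.queue: asyncio.Queue = asyncio.Queue()
        # Must be constructed inside a running event loop.
        self._task = asyncio.get_running_loop().create_task(self._loop())

    async def submit(self, item):
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut                      # resolved by the batcher loop below

    async def _loop(self):
        loop = asyncio.get_running_loop()
        while True:
            item, fut = await self.queue.get()
            items, futs = [item], [fut]
            deadline = loop.time() + self.wait_s
            # Keep draining the queue until the window closes or the batch is full.
            while len(items) < self.max_batch_size:
                timeout = deadline - loop.time()
                if timeout <= 0:
                    break
                try:
                    item, fut = await asyncio.wait_for(self.queue.get(), timeout)
                    items.append(item)
                    futs.append(fut)
                except asyncio.TimeoutError:
                    break
            results = self.batch_fn(items)    # one call for the whole batch
            for f, r in zip(futs, results):
                if not f.done():
                    f.set_result(r)


async def main():
    batcher = MicroBatcher(lambda xs: [x.upper() for x in xs])
    print(await asyncio.gather(*(batcher.submit(w) for w in ["a", "b", "c"])))


asyncio.run(main())

AsyncMicrobatchTokenizer below follows the same shape, with separate queues per kwargs signature and the blocking tokenizer call pushed onto a single worker thread.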

View File

@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
 from typing import Any, Optional
 from unittest.mock import MagicMock

+import pytest
+
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE


-def test_serving_chat_should_set_correct_max_tokens():
+@pytest.mark.asyncio
+async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
         chat_template=CHAT_TEMPLATE,
         chat_template_content_format="auto",
         request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 93

     req.max_tokens = 10
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 15
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 100
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].max_tokens == 5


-def test_serving_chat_could_load_correct_generation_config():
+@pytest.mark.asyncio
+async def test_serving_chat_could_load_correct_generation_config():
     mock_model_config = MockModelConfig()
     mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
         chat_template=CHAT_TEMPLATE,
         chat_template_content_format="auto",
         request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].temperature == 0.5
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.1
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].temperature == 0.1
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.0
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[1].temperature == 0.0
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05


-def test_serving_chat_did_set_correct_cache_salt():
+@pytest.mark.asyncio
+async def test_serving_chat_did_set_correct_cache_salt():
     mock_model_config = MockModelConfig()
     mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
     # By default cache_salt in the engine prompt is not set
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert "cache_salt" not in mock_engine.generate.call_args.args[0]

     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"

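Note: an illustrative sketch, not part of the diff. The tests above change from driving the handler with asyncio.run(...) to being coroutines marked with @pytest.mark.asyncio (the pytest-asyncio plugin), which is needed once create_chat_completion awaits the new async tokenization path. A minimal self-contained example of the same conversion, with a stand-in handler:

import asyncio

import pytest


async def create_chat_completion(req):
    """Stand-in for the real handler; just echoes a field after one await."""
    await asyncio.sleep(0)
    return {"max_tokens": req["max_tokens"]}


# Before the change, tests drove the coroutine themselves:
def test_sync_style():
    result = asyncio.run(create_chat_completion({"max_tokens": 10}))
    assert result["max_tokens"] == 10


# After the change, tests are coroutines run by pytest-asyncio, so awaits
# inside the serving path (e.g. the async tokenizer) work naturally.
@pytest.mark.asyncio
async def test_async_style():
    result = await create_chat_completion({"max_tokens": 10})
    assert result["max_tokens"] == 10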
View File

@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
 import base64
 import io
 import json
 import sys
 import time
-from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
-                             Sequence)
-from concurrent.futures.thread import ThreadPoolExecutor
+from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
 from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                     TypeVar, Union, cast, overload)
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import (is_list_of, make_async, merge_async_iterators,
-                        random_uuid)
+from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
+                        merge_async_iterators, random_uuid)

 logger = init_logger(__name__)
@@ -226,11 +226,19 @@ class OpenAIServing:
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

-        self._tokenize_prompt_input_async = make_async(
-            self._tokenize_prompt_input, executor=self._tokenizer_executor)
-        self._tokenize_prompt_input_or_inputs_async = make_async(
-            self._tokenize_prompt_input_or_inputs,
-            executor=self._tokenizer_executor)
+        self._async_tokenizer_pool: dict[AnyTokenizer,
+                                         AsyncMicrobatchTokenizer] = {}
+
+    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
+        """
+        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
+        given tokenizer.
+        """
+        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
+        if async_tokenizer is None:
+            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
+            self._async_tokenizer_pool[tokenizer] = async_tokenizer
+        return async_tokenizer

     async def _preprocess(
         self,
@@ -467,7 +475,7 @@
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")

-    def _normalize_prompt_text_to_input(
+    async def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -475,38 +483,44 @@
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
                     "do_lower_case", False)):
             prompt = prompt.lower()

         if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+            encoded = await async_tokenizer(
+                prompt, add_special_tokens=add_special_tokens)
         elif truncate_prompt_tokens < 0:
             # Negative means we cap at the model's max length
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=self.max_model_len)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=self.max_model_len)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=truncate_prompt_tokens)

         input_ids = encoded.input_ids
         input_text = prompt

         return self._validate_input(request, input_ids, input_text)

-    def _normalize_prompt_tokens_to_input(
+    async def _normalize_prompt_tokens_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_ids: list[int],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -514,7 +528,7 @@
         else:
             input_ids = prompt_ids[-truncate_prompt_tokens:]

-        input_text = tokenizer.decode(input_ids)
+        input_text = await async_tokenizer.decode(input_ids)

         return self._validate_input(request, input_ids, input_text)
@@ -578,7 +592,7 @@
         return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

-    def _tokenize_prompt_input(
+    async def _tokenize_prompt_input_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -591,23 +605,24 @@
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
         that assumes single input.
         """
-        return next(
-            self._tokenize_prompt_inputs(
+        async for result in self._tokenize_prompt_inputs_async(
                 request,
                 tokenizer,
                 [prompt_input],
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
-            ))
+        ):
+            return result
+        raise ValueError("No results yielded from tokenization")

-    def _tokenize_prompt_inputs(
+    async def _tokenize_prompt_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Iterator[TextTokensPrompt]:
+    ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
         A simpler implementation of
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
@@ -615,7 +630,7 @@
         """
         for text in prompt_inputs:
             if isinstance(text, str):
-                yield self._normalize_prompt_text_to_input(
+                yield await self._normalize_prompt_text_to_input(
                     request,
                     tokenizer,
                     prompt=text,
@@ -623,14 +638,14 @@
                     add_special_tokens=add_special_tokens,
                 )
             else:
-                yield self._normalize_prompt_tokens_to_input(
+                yield await self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
                     prompt_ids=text,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                 )

-    def _tokenize_prompt_input_or_inputs(
+    async def _tokenize_prompt_input_or_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -664,21 +679,31 @@
         # VSCode Pyright extension should still work properly
         # "is False" is required for Pyright to perform type narrowing
         # See: https://github.com/microsoft/pyright/issues/7672
-        inputs_text.extend([
-            self._normalize_prompt_text_to_input(
-                request,
-                tokenizer,
-                prompt=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens,
-                add_special_tokens=add_special_tokens)
-            if prompt_input["is_tokens"] is False else
-            self._normalize_prompt_tokens_to_input(
-                request,
-                tokenizer,
-                prompt_ids=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens)
-            for prompt_input in parse_and_batch_prompt(input_or_inputs)
-        ])
+        # Parse and batch the input prompts
+        batch_inputs = parse_and_batch_prompt(input_or_inputs)
+
+        # Process each input in the batch concurrently
+        tasks = []
+        for prompt_input in batch_inputs:
+            if prompt_input["is_tokens"] is False:
+                task = self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens)
+            else:
+                task = self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens)
+            tasks.append(task)
+
+        # Wait for all tokenization tasks to complete
+        results = await asyncio.gather(*tasks)
+        inputs_text.extend(results)

         return inputs_text, inputs_embeds

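Note: an illustrative sketch, not part of the diff. The rewritten _tokenize_prompt_input_or_inputs_async above builds one coroutine per prompt and awaits them all with asyncio.gather, so the prompts of a batched request are tokenized concurrently instead of one after another. The self-contained example below demonstrates that fan-out, with a hypothetical slow_tokenize helper standing in for a per-prompt await:

import asyncio
import time


async def slow_tokenize(text: str) -> list[int]:
    """Hypothetical per-prompt tokenization with ~2 ms of latency."""
    await asyncio.sleep(0.002)
    return [len(word) for word in text.split()]


async def main():
    prompts = [f"prompt number {i}" for i in range(100)]

    start = time.perf_counter()
    sequential = [await slow_tokenize(p) for p in prompts]   # one by one
    t_seq = time.perf_counter() - start

    start = time.perf_counter()
    gathered = await asyncio.gather(*(slow_tokenize(p) for p in prompts))
    t_gat = time.perf_counter() - start

    assert sequential == gathered                            # same results, same order
    print(f"sequential: {t_seq * 1000:.1f} ms, gathered: {t_gat * 1000:.1f} ms")


asyncio.run(main())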
View File

@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
                              Hashable, Iterable, Iterator, KeysView, Mapping,
                              Sequence)
+from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
@@ -64,6 +65,7 @@ import zmq.asyncio
 from packaging import version
 from packaging.version import Version
 from torch.library import Library
+from transformers.tokenization_utils_base import BatchEncoding
 from typing_extensions import Never, ParamSpec, TypeIs, assert_never

 import vllm.envs as envs
@@ -507,6 +509,196 @@
     return str(uuid.uuid4().hex)


+class AsyncMicrobatchTokenizer:
+    """Asynchronous tokenizer with micro-batching.
+
+    Pulls pending encode/decode requests from a queue and batches them
+    up to reduce overhead. A single-thread ThreadPoolExecutor is used
+    so the event loop stays responsive.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_batch_size: int = 32,
+        batch_wait_timeout_s: float = 0.002,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+
+        self._loop = asyncio.get_running_loop()
+        self._queues: dict[tuple,
+                           asyncio.Queue[Union[tuple[str, dict,
+                                                     asyncio.Future],
+                                               tuple[list[int],
+                                                     asyncio.Future]]]] = {}
+        self._batcher_tasks: list[asyncio.Task] = []
+        # Single-thread executor for blocking tokenizer calls.
+        self._executor = ThreadPoolExecutor(max_workers=1)
+
+    # === Public async API ===
+    async def __call__(self, prompt, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("encode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((prompt, kwargs, result_future))
+        return await result_future
+
+    async def decode(self, token_ids, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("decode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((token_ids, result_future))
+        return await result_future
+
+    # === Internal helpers ===
+    def _get_queue(
+        self, loop: asyncio.AbstractEventLoop, key: tuple
+    ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[
+            list[int], asyncio.Future]]]:
+        """Get the request queue for the given operation key, creating a new
+        queue and batcher task if needed."""
+        queue = self._queues.get(key)
+        if queue is None:
+            self._queues[key] = queue = asyncio.Queue()
+            if key[0] == "encode":
+                can_batch = key[1] != "other"
+                coro = self._batch_encode_loop(queue, can_batch)
+            else:
+                assert key[0] == "decode", \
+                    f"Unknown operation type: {key[0]}."
+                coro = self._batch_decode_loop(queue)
+            self._batcher_tasks.append(loop.create_task(coro))
+        return queue
+
+    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
+        """Batch incoming encode requests for efficiency."""
+        while True:
+            prompt, kwargs, result_future = await queue.get()
+            prompts = [prompt]
+            kwargs_list = [kwargs]
+            result_futures = [result_future]
+
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+            while len(prompts) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    prompt, kwargs, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    prompts.append(prompt)
+                    result_futures.append(result_future)
+                    if not can_batch:
+                        kwargs_list.append(kwargs)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # If every request uses identical kwargs we can run a single
+                # batched tokenizer call for a big speed-up.
+                if can_batch and len(prompts) > 1:
+                    encode_fn = partial(self.tokenizer, prompts, **kwargs)
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for i, fut in enumerate(result_futures):
+                        if not fut.done():
+                            data = {k: v[i] for k, v in results.items()}
+                            fut.set_result(BatchEncoding(data))
+                else:
+                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
+                        self.tokenizer(p, **kw)
+                        for p, kw in zip(prompts, kwargs)
+                    ]
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for fut, res in zip(result_futures, results):
+                        if not fut.done():
+                            fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    async def _batch_decode_loop(self, queue: asyncio.Queue):
+        """Batch incoming decode requests for efficiency."""
+        while True:
+            token_ids, result_future = await queue.get()
+            token_ids_list = [token_ids]
+            result_futures = [result_future]
+
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+            while len(token_ids_list) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    token_ids, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    token_ids_list.append(token_ids)
+                    result_futures.append(result_future)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # Perform a single batched decode call for all requests
+                results = await self._loop.run_in_executor(
+                    self._executor, self.tokenizer.batch_decode,
+                    token_ids_list)
+                for fut, res in zip(result_futures, results):
+                    if not fut.done():
+                        fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    def _queue_key(self, op: str, kwargs: dict) -> tuple:
+        """
+        Return a normalized key describing operation + kwargs.
+
+        - `add_special_tokens`: {True/False}
+        - `truncation`: {True/False}
+        - If `truncation` is False (`max_length` is None),
+          returns a key for a can_batch queue.
+        - If `truncation` is True and `max_length` is None or equals
+          `tokenizer.model_max_length`, returns a key for a can_batch queue.
+        - Otherwise, returns a key for a cannot_batch queue.
+
+        Examples:
+        - Decode: ("decode",)
+        - Encode typical:
+          ("encode", add_special_tokens, bool_truncation, max_length_label)
+        - Fallback: ("encode", "other")
+        """
+        if op == "decode":
+            return ("decode", )
+
+        add_special_tokens = kwargs.get("add_special_tokens", True)
+        truncation = kwargs.get("truncation", False)
+        max_length = kwargs.get("max_length")
+
+        if not truncation:
+            return ("encode", add_special_tokens, False, None)
+
+        model_max = getattr(self.tokenizer, "model_max_length", None)
+        if max_length is None or (model_max is not None
+                                  and max_length == model_max):
+            return ("encode", add_special_tokens, True, "model_max")
+
+        return ("encode", "other")
+
+    def __del__(self):
+        for task in self._batcher_tasks:
+            if not task.done():
+                task.cancel()
+
+
 def make_async(
     func: Callable[P, T],
     executor: Optional[concurrent.futures.Executor] = None
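
Note: a usage sketch, not part of the diff. It assumes a vLLM checkout containing this change and an installed Hugging Face tokenizer ("gpt2" is only an example model). The calls mirror what serving_engine.py does above: the wrapper is constructed inside a running event loop, awaited like a callable for encoding, and its decode coroutine is awaited for detokenization; concurrent calls with identical kwargs are coalesced into one batched tokenizer call.

import asyncio

from transformers import AutoTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


async def main():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Must be created inside a running event loop, since __init__ calls
    # asyncio.get_running_loop().
    async_tok = AsyncMicrobatchTokenizer(tokenizer,
                                         max_batch_size=32,
                                         batch_wait_timeout_s=0.002)

    # Concurrent encodes with identical kwargs land in the same queue and
    # are served by a single batched tokenizer call on the worker thread.
    prompts = [f"hello world {i}" for i in range(8)]
    encodings = await asyncio.gather(
        *(async_tok(p, add_special_tokens=True) for p in prompts))

    # Decodes are micro-batched the same way via tokenizer.batch_decode.
    texts = await asyncio.gather(
        *(async_tok.decode(enc.input_ids) for enc in encodings))
    print(texts[0])


asyncio.run(main())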