[Front-end] microbatch tokenization (#19334)
Signed-off-by: zt2370 <ztang2370@gmail.com>
commit a37d75bbec
parent edd270bc78
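This commit replaces the single-thread make_async wrappers around tokenization in the OpenAI front-end with an AsyncMicrobatchTokenizer (added to vllm/utils.py below) that queues concurrent encode/decode requests and coalesces them into micro-batched tokenizer calls. A minimal usage sketch, assuming a locally available Hugging Face tokenizer ("gpt2" here is only an illustration, not part of the diff):

import asyncio

from transformers import AutoTokenizer

from vllm.utils import AsyncMicrobatchTokenizer


async def main() -> None:
    # The async tokenizer must be constructed inside a running event loop.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    async_tok = AsyncMicrobatchTokenizer(tokenizer)

    # Concurrent awaits with identical kwargs land on the same queue and are
    # merged into one batched tokenizer call by the background batcher task.
    encoded = await asyncio.gather(
        *(async_tok(p, add_special_tokens=True) for p in ["hello", "world"]))
    decoded = await asyncio.gather(
        *(async_tok.decode(e.input_ids) for e in encoded))
    print(decoded)


asyncio.run(main())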
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
from typing import Any, Optional
from unittest.mock import MagicMock

import pytest

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
    assert serving_completion.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
    req.max_tokens = 15

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 5
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
    req.max_tokens = 100

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].max_tokens == 5


def test_serving_chat_could_load_correct_generation_config():
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():

    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None)

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.5
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
    req.temperature = 0.1

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.1
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
    req.temperature = 0.0

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)

    assert mock_engine.generate.call_args.args[1].temperature == 0.0
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05


def test_serving_chat_did_set_correct_cache_salt():
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt():
    mock_model_config = MockModelConfig()

    mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():

    # By default cache_salt in the engine prompt is not set
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)
    assert "cache_salt" not in mock_engine.generate.call_args.args[0]

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))
        await serving_chat.create_chat_completion(req)
    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
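The test changes above all follow the same sync-to-async conversion. A self-contained sketch of the pattern, assuming the pytest-asyncio plugin is installed (the coroutine below is a stand-in, not the real create_chat_completion):

import asyncio

import pytest


async def create_chat_completion(req: dict) -> dict:
    # Hypothetical stand-in for serving_chat.create_chat_completion.
    await asyncio.sleep(0)
    return {"max_tokens": req.get("max_tokens", 93)}


@pytest.mark.asyncio
async def test_default_max_tokens():
    # The test is itself a coroutine, so it awaits directly instead of
    # wrapping every call in asyncio.run().
    result = await create_chat_completion({})
    assert result["max_tokens"] == 93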
@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import io
import json
import sys
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
                             Sequence)
from concurrent.futures.thread import ThreadPoolExecutor
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                    TypeVar, Union, cast, overload)
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
                        random_uuid)
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)

logger = init_logger(__name__)
@@ -226,11 +226,19 @@ class OpenAIServing:

        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}

    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer

    async def _preprocess(
        self,
@@ -467,7 +475,7 @@ class OpenAIServing:
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
@@ -475,38 +483,44 @@ class OpenAIServing:
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
            encoded = await async_tokenizer(
                prompt, add_special_tokens=add_special_tokens)
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=self.max_model_len)
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=self.max_model_len)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
@@ -514,7 +528,7 @@ class OpenAIServing:
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)
        input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)
@@ -578,7 +592,7 @@ class OpenAIServing:

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
@@ -591,23 +605,24 @@ class OpenAIServing:
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
        async for result in self._tokenize_prompt_inputs_async(
                request,
                tokenizer,
                [prompt_input],
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))
        ):
            return result
        raise ValueError("No results yielded from tokenization")

    def _tokenize_prompt_inputs(
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
@@ -615,7 +630,7 @@ class OpenAIServing:
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                yield await self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
@@ -623,14 +638,14 @@ class OpenAIServing:
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
    async def _tokenize_prompt_input_or_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
@@ -664,21 +679,31 @@ class OpenAIServing:
        # VSCode Pyright extension should still work properly
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        inputs_text.extend([
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ])

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(input_or_inputs)

        # Process each input in the batch concurrently
        tasks = []
        for prompt_input in batch_inputs:
            if prompt_input["is_tokens"] is False:
                task = self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens)
            else:
                task = self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens)
            tasks.append(task)

        # Wait for all tokenization tasks to complete
        results = await asyncio.gather(*tasks)
        inputs_text.extend(results)

        return inputs_text, inputs_embeds
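The rewritten _tokenize_prompt_input_or_inputs_async fans out one coroutine per prompt and gathers them, which lets the micro-batching tokenizer coalesce concurrent requests. A self-contained sketch of the same pattern (tokenize_one is a hypothetical stand-in for _normalize_prompt_text_to_input):

import asyncio


async def tokenize_one(prompt: str) -> list[str]:
    # Pretend tokenization; yields control so other tasks can run concurrently.
    await asyncio.sleep(0)
    return prompt.split()


async def tokenize_all(prompts: list[str]) -> list[list[str]]:
    # Build one task per prompt, then await them together.
    tasks = [tokenize_one(p) for p in prompts]
    return await asyncio.gather(*tasks)


print(asyncio.run(tokenize_all(["hello world", "microbatch tokenization"])))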
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
                             Hashable, Iterable, Iterator, KeysView, Mapping,
                             Sequence)
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
@@ -64,6 +65,7 @@ import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import Never, ParamSpec, TypeIs, assert_never

import vllm.envs as envs
@@ -507,6 +509,196 @@ def random_uuid() -> str:
    return str(uuid.uuid4().hex)


class AsyncMicrobatchTokenizer:
    """Asynchronous tokenizer with micro-batching.

    Pulls pending encode/decode requests from a queue and batches them
    up to reduce overhead. A single-thread ThreadPoolExecutor is used
    so the event loop stays responsive.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s

        self._loop = asyncio.get_running_loop()
        self._queues: dict[tuple,
                           asyncio.Queue[Union[tuple[str, dict,
                                                     asyncio.Future],
                                               tuple[list[int],
                                                     asyncio.Future]]]] = {}
        self._batcher_tasks: list[asyncio.Task] = []

        # Single-thread executor for blocking tokenizer calls.
        self._executor = ThreadPoolExecutor(max_workers=1)

    # === Public async API ===
    async def __call__(self, prompt, **kwargs):
        result_future: asyncio.Future = self._loop.create_future()
        key = self._queue_key("encode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((prompt, kwargs, result_future))
        return await result_future

    async def decode(self, token_ids, **kwargs):
        result_future: asyncio.Future = self._loop.create_future()
        key = self._queue_key("decode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((token_ids, result_future))
        return await result_future

    # === Internal helpers ===
    def _get_queue(
        self, loop: asyncio.AbstractEventLoop, key: tuple
    ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future],
                             tuple[list[int], asyncio.Future]]]:
        """Get the request queue for the given operation key, creating a new
        queue and batcher task if needed."""
        queue = self._queues.get(key)
        if queue is None:
            self._queues[key] = queue = asyncio.Queue()
            if key[0] == "encode":
                can_batch = key[1] != "other"
                coro = self._batch_encode_loop(queue, can_batch)
            else:
                assert key[0] == "decode", \
                    f"Unknown operation type: {key[0]}."
                coro = self._batch_decode_loop(queue)
            self._batcher_tasks.append(loop.create_task(coro))
        return queue

    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
        """Batch incoming encode requests for efficiency."""
        while True:
            prompt, kwargs, result_future = await queue.get()
            prompts = [prompt]
            kwargs_list = [kwargs]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(prompts) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    prompt, kwargs, result_future = await asyncio.wait_for(
                        queue.get(), timeout)
                    prompts.append(prompt)
                    result_futures.append(result_future)
                    if not can_batch:
                        kwargs_list.append(kwargs)
                except asyncio.TimeoutError:
                    break

            try:
                # If every request uses identical kwargs we can run a single
                # batched tokenizer call for a big speed-up.
                if can_batch and len(prompts) > 1:
                    encode_fn = partial(self.tokenizer, prompts, **kwargs)
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn)

                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            data = {k: v[i] for k, v in results.items()}
                            fut.set_result(BatchEncoding(data))
                else:
                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
                        self.tokenizer(p, **kw)
                        for p, kw in zip(prompts, kwargs)
                    ]
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn)

                    for fut, res in zip(result_futures, results):
                        if not fut.done():
                            fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _batch_decode_loop(self, queue: asyncio.Queue):
        """Batch incoming decode requests for efficiency."""
        while True:
            token_ids, result_future = await queue.get()
            token_ids_list = [token_ids]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(token_ids_list) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    token_ids, result_future = await asyncio.wait_for(
                        queue.get(), timeout)
                    token_ids_list.append(token_ids)
                    result_futures.append(result_future)
                except asyncio.TimeoutError:
                    break

            try:
                # Perform a single batched decode call for all requests
                results = await self._loop.run_in_executor(
                    self._executor, self.tokenizer.batch_decode,
                    token_ids_list)
                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
            - If `truncation` is False (`max_length` is None),
              returns a key for a can_batch queue.
            - If `truncation` is True and `max_length` is None or equals
              `tokenizer.model_max_length`, returns a key for a can_batch queue.
            - Otherwise, returns a key for a cannot_batch queue.

        Examples:
        - Decode: ("decode",)
        - Encode typical:
          ("encode", add_special_tokens, bool_truncation, max_length_label)
        - Fallback: ("encode", "other")
        """

        if op == "decode":
            return ("decode", )

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return ("encode", add_special_tokens, False, None)

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None
                                  and max_length == model_max):
            return ("encode", add_special_tokens, True, "model_max")

        return ("encode", "other")

    def __del__(self):
        for task in self._batcher_tasks:
            if not task.done():
                task.cancel()


def make_async(
    func: Callable[P, T],
    executor: Optional[concurrent.futures.Executor] = None
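The batching rule encoded by _queue_key above can be illustrated with a small standalone sketch: requests whose kwargs normalize to the same key share a queue and can be merged into one batched call, while custom max_length values fall into the non-batchable "other" queue. Here queue_key only mirrors the method's logic for illustration, and model_max_length is an assumed default, not a value from the diff.

def queue_key(op: str, kwargs: dict, model_max_length: int = 2048) -> tuple:
    # Mirrors AsyncMicrobatchTokenizer._queue_key; illustration only.
    if op == "decode":
        return ("decode", )
    add_special_tokens = kwargs.get("add_special_tokens", True)
    truncation = kwargs.get("truncation", False)
    max_length = kwargs.get("max_length")
    if not truncation:
        return ("encode", add_special_tokens, False, None)
    if max_length is None or max_length == model_max_length:
        return ("encode", add_special_tokens, True, "model_max")
    return ("encode", "other")


# Two requests that share a batchable queue, and one that does not:
assert queue_key("encode", {}) == queue_key("encode", {"add_special_tokens": True})
assert queue_key("encode", {"truncation": True, "max_length": 128}) == ("encode", "other")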