[Front-end] microbatch tokenization (#19334)
Signed-off-by: zt2370 <ztang2370@gmail.com>
This commit is contained in:
parent edd270bc78
commit a37d75bbec
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
 from typing import Any, Optional
 from unittest.mock import MagicMock
 
+import pytest
+
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE
 
 
-def test_serving_chat_should_set_correct_max_tokens():
+@pytest.mark.asyncio
+async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
                                      request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
     req.max_tokens = 10
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 15
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 100
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 93
 
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     req.max_tokens = 5
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
 
-def test_serving_chat_could_load_correct_generation_config():
+@pytest.mark.asyncio
+async def test_serving_chat_could_load_correct_generation_config():
 
     mock_model_config = MockModelConfig()
     mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
                                      request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
     )
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.5
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.1
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.1
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.0
 
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
 
     assert mock_engine.generate.call_args.args[1].temperature == 0.0
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
 
 
-def test_serving_chat_did_set_correct_cache_salt():
+@pytest.mark.asyncio
+async def test_serving_chat_did_set_correct_cache_salt():
     mock_model_config = MockModelConfig()
 
     mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
 
     # By default cache_salt in the engine prompt is not set
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert "cache_salt" not in mock_engine.generate.call_args.args[0]
 
     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
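Every test change above is the same mechanical conversion: the test used to drive the handler with asyncio.run(...), and now the test itself is a coroutine marked @pytest.mark.asyncio that awaits the handler, so the request runs on a long-lived event loop (which the new async tokenization path needs for its background batcher tasks). A tiny before/after sketch, assuming the pytest-asyncio plugin is installed; do_request() is a hypothetical stand-in for serving_chat.create_chat_completion(req), not part of the patch. The hunks that follow come from the OpenAI serving layer (their docstrings reference vllm.entrypoints.openai.serving_engine.OpenAIServing).

import asyncio

import pytest


async def do_request() -> str:
    # Hypothetical stand-in for serving_chat.create_chat_completion(req).
    await asyncio.sleep(0)
    return "ok"


def test_old_style():
    # Before: each call spun up and tore down its own event loop.
    assert asyncio.run(do_request()) == "ok"


@pytest.mark.asyncio
async def test_new_style():
    # After: pytest-asyncio supplies the loop and the test awaits directly,
    # matching the now-async serving code paths.
    assert await do_request() == "ok"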
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
 import base64
 import io
 import json
 import sys
 import time
-from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
-                             Sequence)
-from concurrent.futures.thread import ThreadPoolExecutor
+from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
 from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                     TypeVar, Union, cast, overload)
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import (is_list_of, make_async, merge_async_iterators,
-                        random_uuid)
+from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
+                        merge_async_iterators, random_uuid)
 
 logger = init_logger(__name__)
 
@@ -226,11 +226,19 @@ class OpenAIServing:
 
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
 
-        self._tokenize_prompt_input_async = make_async(
-            self._tokenize_prompt_input, executor=self._tokenizer_executor)
-        self._tokenize_prompt_input_or_inputs_async = make_async(
-            self._tokenize_prompt_input_or_inputs,
-            executor=self._tokenizer_executor)
+        self._async_tokenizer_pool: dict[AnyTokenizer,
+                                         AsyncMicrobatchTokenizer] = {}
+
+    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
+        """
+        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
+        given tokenizer.
+        """
+        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
+        if async_tokenizer is None:
+            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
+            self._async_tokenizer_pool[tokenizer] = async_tokenizer
+        return async_tokenizer
 
     async def _preprocess(
         self,
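The make_async(...) thread-pool wrappers are replaced by _get_async_tokenizer, which lazily builds one AsyncMicrobatchTokenizer per tokenizer object and reuses it for the lifetime of the server, so requests that share a tokenizer also share one micro-batching queue. A minimal, self-contained sketch of the same lazy per-key cache pattern (the tokenizer and wrapper below are illustrative stand-ins, not vLLM classes):

import asyncio
from typing import Any, Awaitable, Callable


class PerTokenizerCache:
    """Illustrative version of the _async_tokenizer_pool dict in the diff."""

    def __init__(self) -> None:
        self._pool: dict[Any, Callable[[str], Awaitable[list[str]]]] = {}

    def get_async_tokenizer(self, tokenizer: Callable[[str], list[str]]):
        async_tokenizer = self._pool.get(tokenizer)
        if async_tokenizer is None:
            # Built on first use, then reused: callers that share a tokenizer
            # share one wrapper (and, in vLLM, one batching queue).
            async def async_tokenizer(prompt: str) -> list[str]:
                return tokenizer(prompt)

            self._pool[tokenizer] = async_tokenizer
        return async_tokenizer


async def main() -> None:
    cache = PerTokenizerCache()
    tok = cache.get_async_tokenizer(str.split)
    assert tok is cache.get_async_tokenizer(str.split)  # cached on second call
    print(await tok("one two three"))


asyncio.run(main())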
@@ -467,7 +475,7 @@ class OpenAIServing:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
 
-    def _normalize_prompt_text_to_input(
+    async def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -475,38 +483,44 @@ class OpenAIServing:
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
                     "do_lower_case", False)):
             prompt = prompt.lower()
 
         if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+            encoded = await async_tokenizer(
+                prompt, add_special_tokens=add_special_tokens)
         elif truncate_prompt_tokens < 0:
             # Negative means we cap at the model's max length
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=self.max_model_len)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=self.max_model_len)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
+            encoded = await async_tokenizer(
+                prompt,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=truncate_prompt_tokens)
 
         input_ids = encoded.input_ids
 
         input_text = prompt
 
         return self._validate_input(request, input_ids, input_text)
 
-    def _normalize_prompt_tokens_to_input(
+    async def _normalize_prompt_tokens_to_input(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_ids: list[int],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        async_tokenizer = self._get_async_tokenizer(tokenizer)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -514,7 +528,7 @@ class OpenAIServing:
         else:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
 
-        input_text = tokenizer.decode(input_ids)
+        input_text = await async_tokenizer.decode(input_ids)
 
         return self._validate_input(request, input_ids, input_text)
 
@@ -578,7 +592,7 @@ class OpenAIServing:
 
         return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
 
-    def _tokenize_prompt_input(
+    async def _tokenize_prompt_input_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -591,23 +605,24 @@ class OpenAIServing:
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
         that assumes single input.
         """
-        return next(
-            self._tokenize_prompt_inputs(
+        async for result in self._tokenize_prompt_inputs_async(
                 request,
                 tokenizer,
                 [prompt_input],
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
-            ))
+        ):
+            return result
+        raise ValueError("No results yielded from tokenization")
 
-    def _tokenize_prompt_inputs(
+    async def _tokenize_prompt_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Iterator[TextTokensPrompt]:
+    ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
         A simpler implementation of
         [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
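The rewritten _tokenize_prompt_input_async keeps its single-input signature but now wraps the async-generator variant: it iterates, returns the first yielded result, and raises if nothing was produced. A standalone sketch of that wrapper pattern (the names here are illustrative, not the vLLM helpers):

import asyncio
from collections.abc import AsyncGenerator


async def tokenize_many(items: list[str]) -> AsyncGenerator[list[str], None]:
    # Hypothetical stand-in for _tokenize_prompt_inputs_async: yields one
    # result per input instead of returning a list.
    for item in items:
        yield item.split()


async def tokenize_one(item: str) -> list[str]:
    # Same shape as _tokenize_prompt_input_async: take the first (and only)
    # yielded result, and fail loudly if the generator was empty.
    async for result in tokenize_many([item]):
        return result
    raise ValueError("No results yielded from tokenization")


print(asyncio.run(tokenize_one("hello async world")))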
@@ -615,7 +630,7 @@ class OpenAIServing:
         """
         for text in prompt_inputs:
             if isinstance(text, str):
-                yield self._normalize_prompt_text_to_input(
+                yield await self._normalize_prompt_text_to_input(
                     request,
                     tokenizer,
                     prompt=text,
@@ -623,14 +638,14 @@ class OpenAIServing:
                     add_special_tokens=add_special_tokens,
                 )
             else:
-                yield self._normalize_prompt_tokens_to_input(
+                yield await self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
                     prompt_ids=text,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                 )
 
-    def _tokenize_prompt_input_or_inputs(
+    async def _tokenize_prompt_input_or_inputs_async(
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
@@ -664,21 +679,31 @@ class OpenAIServing:
         # VSCode Pyright extension should still work properly
         # "is False" is required for Pyright to perform type narrowing
         # See: https://github.com/microsoft/pyright/issues/7672
-        inputs_text.extend([
-            self._normalize_prompt_text_to_input(
-                request,
-                tokenizer,
-                prompt=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens,
-                add_special_tokens=add_special_tokens)
-            if prompt_input["is_tokens"] is False else
-            self._normalize_prompt_tokens_to_input(
-                request,
-                tokenizer,
-                prompt_ids=prompt_input["content"],
-                truncate_prompt_tokens=truncate_prompt_tokens)
-            for prompt_input in parse_and_batch_prompt(input_or_inputs)
-        ])
+
+        # Parse and batch the input prompts
+        batch_inputs = parse_and_batch_prompt(input_or_inputs)
+
+        # Process each input in the batch concurrently
+        tasks = []
+        for prompt_input in batch_inputs:
+            if prompt_input["is_tokens"] is False:
+                task = self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens)
+            else:
+                task = self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens)
+            tasks.append(task)
+
+        # Wait for all tokenization tasks to complete
+        results = await asyncio.gather(*tasks)
+        inputs_text.extend(results)
 
         return inputs_text, inputs_embeds
 
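The batch path above no longer builds results eagerly inside a list comprehension; it collects one coroutine per prompt and awaits them together, which is what gives the micro-batching tokenizer a window to coalesce individual requests into one tokenizer call. A minimal sketch of that fan-out/gather shape (fake_tokenize stands in for the _normalize_prompt_*_to_input coroutines):

import asyncio


async def fake_tokenize(prompt: str) -> list[str]:
    # Stand-in for an awaitable per-prompt tokenization call.
    await asyncio.sleep(0)
    return prompt.split()


async def tokenize_batch(prompts: list[str]) -> list[list[str]]:
    # Same structure as _tokenize_prompt_input_or_inputs_async: one coroutine
    # per input, gathered so they run concurrently and can be micro-batched.
    tasks = [fake_tokenize(p) for p in prompts]
    return await asyncio.gather(*tasks)


print(asyncio.run(tokenize_batch(["a b", "c d e"])))

The final hunks below add AsyncMicrobatchTokenizer itself, next to make_async in vllm.utils.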
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
                              Hashable, Iterable, Iterator, KeysView, Mapping,
                              Sequence)
+from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
@@ -64,6 +65,7 @@ import zmq.asyncio
 from packaging import version
 from packaging.version import Version
 from torch.library import Library
+from transformers.tokenization_utils_base import BatchEncoding
 from typing_extensions import Never, ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
@@ -507,6 +509,196 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+class AsyncMicrobatchTokenizer:
+    """Asynchronous tokenizer with micro-batching.
+
+    Pulls pending encode/decode requests from a queue and batches them
+    up to reduce overhead. A single-thread ThreadPoolExecutor is used
+    so the event loop stays responsive.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_batch_size: int = 32,
+        batch_wait_timeout_s: float = 0.002,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+
+        self._loop = asyncio.get_running_loop()
+        self._queues: dict[tuple,
+                           asyncio.Queue[Union[tuple[str, dict,
+                                                     asyncio.Future],
+                                               tuple[list[int],
+                                                     asyncio.Future]]]] = {}
+        self._batcher_tasks: list[asyncio.Task] = []
+
+        # Single-thread executor for blocking tokenizer calls.
+        self._executor = ThreadPoolExecutor(max_workers=1)
+
+    # === Public async API ===
+    async def __call__(self, prompt, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("encode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((prompt, kwargs, result_future))
+        return await result_future
+
+    async def decode(self, token_ids, **kwargs):
+        result_future: asyncio.Future = self._loop.create_future()
+        key = self._queue_key("decode", kwargs)
+        queue = self._get_queue(self._loop, key)
+        await queue.put((token_ids, result_future))
+        return await result_future
+
+    # === Internal helpers ===
+    def _get_queue(
+        self, loop: asyncio.AbstractEventLoop, key: tuple
+    ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[
+            list[int], asyncio.Future]]]:
+        """Get the request queue for the given operation key, creating a new
+        queue and batcher task if needed."""
+        queue = self._queues.get(key)
+        if queue is None:
+            self._queues[key] = queue = asyncio.Queue()
+            if key[0] == "encode":
+                can_batch = key[1] != "other"
+                coro = self._batch_encode_loop(queue, can_batch)
+            else:
+                assert key[0] == "decode", \
+                    f"Unknown operation type: {key[0]}."
+                coro = self._batch_decode_loop(queue)
+            self._batcher_tasks.append(loop.create_task(coro))
+        return queue
+
+    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
+        """Batch incoming encode requests for efficiency."""
+        while True:
+            prompt, kwargs, result_future = await queue.get()
+            prompts = [prompt]
+            kwargs_list = [kwargs]
+            result_futures = [result_future]
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+
+            while len(prompts) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    prompt, kwargs, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    prompts.append(prompt)
+                    result_futures.append(result_future)
+                    if not can_batch:
+                        kwargs_list.append(kwargs)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # If every request uses identical kwargs we can run a single
+                # batched tokenizer call for a big speed-up.
+                if can_batch and len(prompts) > 1:
+                    encode_fn = partial(self.tokenizer, prompts, **kwargs)
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for i, fut in enumerate(result_futures):
+                        if not fut.done():
+                            data = {k: v[i] for k, v in results.items()}
+                            fut.set_result(BatchEncoding(data))
+                else:
+                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
+                        self.tokenizer(p, **kw)
+                        for p, kw in zip(prompts, kwargs)
+                    ]
+                    results = await self._loop.run_in_executor(
+                        self._executor, encode_fn)
+
+                    for fut, res in zip(result_futures, results):
+                        if not fut.done():
+                            fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    async def _batch_decode_loop(self, queue: asyncio.Queue):
+        """Batch incoming decode requests for efficiency."""
+        while True:
+            token_ids, result_future = await queue.get()
+            token_ids_list = [token_ids]
+            result_futures = [result_future]
+            deadline = self._loop.time() + self.batch_wait_timeout_s
+
+            while len(token_ids_list) < self.max_batch_size:
+                timeout = deadline - self._loop.time()
+                if timeout <= 0:
+                    break
+                try:
+                    token_ids, result_future = await asyncio.wait_for(
+                        queue.get(), timeout)
+                    token_ids_list.append(token_ids)
+                    result_futures.append(result_future)
+                except asyncio.TimeoutError:
+                    break
+
+            try:
+                # Perform a single batched decode call for all requests
+                results = await self._loop.run_in_executor(
+                    self._executor, self.tokenizer.batch_decode,
+                    token_ids_list)
+                for fut, res in zip(result_futures, results):
+                    if not fut.done():
+                        fut.set_result(res)
+            except Exception as e:
+                for fut in result_futures:
+                    if not fut.done():
+                        fut.set_exception(e)
+
+    def _queue_key(self, op: str, kwargs: dict) -> tuple:
+        """
+        Return a normalized key describing operation + kwargs.
+
+        - `add_special_tokens`: {True/False}
+        - `truncation`: {True/False}
+            - If `truncation` is False (`max_length` is None),
+              returns a key for a can_batch queue.
+            - If `truncation` is True and `max_length` is None or equals
+              `tokenizer.model_max_length`, returns a key for a can_batch queue.
+            - Otherwise, returns a key for a cannot_batch queue.
+
+        Examples:
+            - Decode: ("decode",)
+            - Encode typical:
+              ("encode", add_special_tokens, bool_truncation, max_length_label)
+            - Fallback: ("encode", "other")
+        """
+
+        if op == "decode":
+            return ("decode", )
+
+        add_special_tokens = kwargs.get("add_special_tokens", True)
+        truncation = kwargs.get("truncation", False)
+        max_length = kwargs.get("max_length")
+
+        if not truncation:
+            return ("encode", add_special_tokens, False, None)
+
+        model_max = getattr(self.tokenizer, "model_max_length", None)
+        if max_length is None or (model_max is not None
+                                  and max_length == model_max):
+            return ("encode", add_special_tokens, True, "model_max")
+
+        return ("encode", "other")
+
+    def __del__(self):
+        for task in self._batcher_tasks:
+            if not task.done():
+                task.cancel()
+
+
 def make_async(
     func: Callable[P, T],
     executor: Optional[concurrent.futures.Executor] = None
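A hedged usage sketch of the new helper, not part of the diff: it constructs the wrapper inside a running event loop (the constructor calls asyncio.get_running_loop() and later spawns background batcher tasks on it) and awaits encode and decode concurrently. The model name and timing values are illustrative only.

import asyncio

from transformers import AutoTokenizer

from vllm.utils import AsyncMicrobatchTokenizer


async def main() -> None:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Must be created inside a running event loop.
    async_tok = AsyncMicrobatchTokenizer(tokenizer,
                                         max_batch_size=32,
                                         batch_wait_timeout_s=0.002)

    # Concurrent awaits that land inside the batching window and use identical
    # kwargs are folded into one underlying tokenizer call.
    encoded = await asyncio.gather(*[
        async_tok(prompt, add_special_tokens=True)
        for prompt in ["Hello world", "Microbatching amortizes overhead"]
    ])
    print([e.input_ids for e in encoded])

    # Decode requests go through their own queue and use batch_decode.
    print(await async_tok.decode(encoded[0].input_ids))


asyncio.run(main())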