[Refactor] Introduce basic Renderer for completion-style request (#24010)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent e919d6f549
commit 712b273f65
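The commit replaces the per-endpoint `_preprocess_completion` step with a reusable `CompletionRenderer`. For orientation, a minimal sketch of the new API using only names from the diff below; `_Config` is a hypothetical stand-in for `ModelConfig`, mirroring the `MockModelConfig` used in the new tests:

import asyncio
from vllm.entrypoints.renderer import CompletionRenderer

class _Config:  # hypothetical stand-in for ModelConfig, mirroring the test fixtures
    max_model_len = 100
    encoder_config = None

async def main():
    # Pre-tokenized input bypasses the tokenizer entirely, so none is needed here;
    # text input would require a real (or mocked) tokenizer.
    renderer = CompletionRenderer(model_config=_Config(), tokenizer=None,
                                  async_tokenizer_pool={})
    prompts = await renderer.render_prompt(prompt_or_prompts=[101, 7592, 2088],
                                           max_length=100)
    print(prompts)  # expected: [{'prompt_token_ids': [101, 7592, 2088]}]

asyncio.run(main())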
tests/entrypoints/test_renderer.py (new file, 163 lines added)
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.entrypoints.renderer import CompletionRenderer


@dataclass
class MockModelConfig:
    max_model_len: int = 100
    encoder_config: Optional[dict] = None


class MockTokenizerResult:

    def __init__(self, input_ids):
        self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
    return MockModelConfig()


@pytest.fixture
def mock_tokenizer():
    tokenizer = MagicMock()
    return tokenizer


@pytest.fixture
def mock_async_tokenizer():
    async_tokenizer = AsyncMock()
    return async_tokenizer


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
    return CompletionRenderer(model_config=mock_model_config,
                              tokenizer=mock_tokenizer,
                              async_tokenizer_pool={})


class TestRenderPrompt:
    """Test Category A: Basic Functionality Tests"""

    @pytest.mark.asyncio
    async def test_token_input(self, renderer):
        tokens = [101, 7592, 2088]
        results = await renderer.render_prompt(prompt_or_prompts=tokens,
                                               max_length=100)

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    @pytest.mark.asyncio
    async def test_token_list_input(self, renderer):
        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
                                               max_length=100)

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    @pytest.mark.asyncio
    async def test_text_input(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100)

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        mock_async_tokenizer.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_list_input(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        text_list_input = ["Hello world", "How are you?", "Good morning"]
        results = await renderer.render_prompt(
            prompt_or_prompts=text_list_input, max_length=100)

        assert len(results) == 3
        for result in results:
            assert result["prompt_token_ids"] == [101, 7592, 2088]
        assert mock_async_tokenizer.call_count == 3

    @pytest.mark.asyncio
    async def test_no_truncation(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100)

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert "truncation" not in call_args.kwargs or call_args.kwargs[
            "truncation"] is False

    @pytest.mark.asyncio
    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])  # Truncated
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100,
                                               truncate_prompt_tokens=50)

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_token_truncation_last_elements(self, renderer):
        # Test that token truncation keeps the last N elements
        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
                       109]  # 10 tokens
        results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
                                               max_length=100,
                                               truncate_prompt_tokens=5)

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
    async def test_max_length_exceeded(self, renderer):
        long_tokens = list(range(150))  # Exceeds max_model_len=100

        with pytest.raises(ValueError, match="maximum context length"):
            await renderer.render_prompt(prompt_or_prompts=long_tokens,
                                         max_length=100)

    @pytest.mark.asyncio
    async def test_no_tokenizer_for_text(self, mock_model_config):
        renderer_no_tokenizer = CompletionRenderer(
            model_config=mock_model_config,
            tokenizer=None,
            async_tokenizer_pool={})

        with pytest.raises(ValueError, match="No tokenizer available"):
            await renderer_no_tokenizer.render_prompt(
                prompt_or_prompts="Hello world", max_length=100)
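The `@pytest.mark.asyncio` markers above require the pytest-asyncio plugin. Assuming pytest and pytest-asyncio are installed, the new module can be run on its own, for example programmatically:

# Run only the new renderer tests from the repository root;
# equivalent to `pytest -q tests/entrypoints/test_renderer.py`.
import sys
import pytest

sys.exit(pytest.main(["-q", "tests/entrypoints/test_renderer.py"]))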
vllm/entrypoints/openai/serving_engine.py
@@ -62,8 +62,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TranslationRequest)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
 # yapf: enable
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -243,6 +245,16 @@ class OpenAIServing:
                                          AsyncMicrobatchTokenizer] = {}
         self.log_error_stack = log_error_stack

+    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
+        """
+        Get a Renderer instance with the provided tokenizer.
+        Uses shared async tokenizer pool for efficiency.
+        """
+        return CompletionRenderer(
+            model_config=self.model_config,
+            tokenizer=tokenizer,
+            async_tokenizer_pool=self._async_tokenizer_pool)
+
     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
         Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
@@ -1098,7 +1110,7 @@ class OpenAIServing:
     def _log_inputs(
         self,
         request_id: str,
-        inputs: RequestPrompt,
+        inputs: Union[RequestPrompt, PromptType],
         params: Optional[Union[SamplingParams, PoolingParams,
                                BeamSearchParams]],
         lora_request: Optional[LoRARequest],
@@ -1110,11 +1122,9 @@ class OpenAIServing:
             prompt = inputs
         elif isinstance(inputs, list):
             prompt_token_ids = inputs
-        elif "prompt_embeds" in inputs:
-            prompt_embeds = inputs.get("prompt_embeds")
         else:
-            prompt = inputs["prompt"]
-            prompt_token_ids = inputs["prompt_token_ids"]
+            prompt = getattr(inputs, 'prompt', None)
+            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

         self.request_logger.log_inputs(
             request_id,
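`_get_renderer` above builds a fresh `CompletionRenderer` per request but hands it the serving object's shared `_async_tokenizer_pool`, so renderers created for the same tokenizer reuse one `AsyncMicrobatchTokenizer` rather than re-wrapping it. A small sketch of that sharing, using mocks in the style of the new tests (not real serving code; `_Config` is a hypothetical `ModelConfig` stand-in):

import asyncio
from unittest.mock import AsyncMock, MagicMock
from vllm.entrypoints.renderer import CompletionRenderer

class _Config:  # hypothetical stand-in for ModelConfig
    max_model_len = 100
    encoder_config = None

tokenizer = MagicMock()
# The pool maps a tokenizer to its async wrapper; here the wrapper is mocked out.
shared_pool = {tokenizer: AsyncMock(return_value=MagicMock(input_ids=[101, 7592, 2088]))}

r1 = CompletionRenderer(model_config=_Config(), tokenizer=tokenizer,
                        async_tokenizer_pool=shared_pool)
r2 = CompletionRenderer(model_config=_Config(), tokenizer=tokenizer,
                        async_tokenizer_pool=shared_pool)

async def main():
    # Both renderers resolve to the same pooled entry instead of re-wrapping the tokenizer.
    assert r1._get_async_tokenizer() is r2._get_async_tokenizer()
    prompts = await r1.render_prompt(prompt_or_prompts="Hello world", max_length=100)
    print(prompts)  # expected: [{'prompt_token_ids': [101, 7592, 2088]}]

asyncio.run(main())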
vllm/entrypoints/openai/serving_pooling.py
@@ -4,7 +4,7 @@
 import asyncio
 import base64
 import time
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from typing import Final, Literal, Optional, Union, cast

 import jinja2
@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               PoolingRequest, PoolingResponse,
                                               PoolingResponseData, UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
@@ -104,6 +104,7 @@ class OpenAIServingPooling(OpenAIServing):
            else:
                tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                   )
+           renderer = self._get_renderer(tokenizer)

            if getattr(request, "dimensions", None) is not None:
                return self.create_error_response(
@@ -126,14 +127,11 @@ class OpenAIServingPooling(OpenAIServing):

                engine_prompts = await self.io_processor.pre_process_async(
                    prompt=validated_prompt, request_id=request_id)
-               request_prompts: Sequence[RequestPrompt] = [
-                   ""
-               ] * len(engine_prompts)

            elif isinstance(request, PoolingChatRequest):
                (
                    _,
-                   request_prompts,
+                   _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
@@ -149,12 +147,12 @@ class OpenAIServingPooling(OpenAIServing):
                    add_special_tokens=request.add_special_tokens,
                )
            elif isinstance(request, PoolingCompletionRequest):
-               (request_prompts,
-                engine_prompts) = await self._preprocess_completion(
-                   request,
-                   tokenizer,
-                   request.input,
+               engine_prompts = await renderer.render_prompt(
+                   prompt_or_prompts=request.input,
+                   max_length=self.max_model_len,
+                   truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
+                   cache_salt=getattr(request, 'cache_salt', None),
                )
            else:
                raise ValueError(
@@ -177,7 +175,7 @@ class OpenAIServingPooling(OpenAIServing):
            request_id_item = f"{request_id}-{i}"

            self._log_inputs(request_id_item,
-                            request_prompts[i],
+                            engine_prompt,
                             params=pooling_params,
                             lora_request=lora_request)

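The pooling completion path above now forwards `cache_salt` (when present on the request) into `render_prompt`, and the renderer stamps it onto every engine prompt it emits. A runnable illustration with pre-tokenized input, again using the hypothetical `_Config` stand-in from the tests:

import asyncio
from vllm.entrypoints.renderer import CompletionRenderer

class _Config:  # hypothetical stand-in for ModelConfig
    max_model_len = 100
    encoder_config = None

async def main():
    renderer = CompletionRenderer(model_config=_Config(), tokenizer=None,
                                  async_tokenizer_pool={})
    prompts = await renderer.render_prompt(prompt_or_prompts=[[1, 2, 3], [4, 5]],
                                           max_length=10,
                                           cache_salt="request-abc")
    for p in prompts:
        print(p["prompt_token_ids"], p["cache_salt"])
    # expected:
    # [1, 2, 3] request-abc
    # [4, 5] request-abc

asyncio.run(main())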
vllm/entrypoints/openai/serving_tokenization.py
@@ -65,6 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
            lora_request = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)
+           renderer = self._get_renderer(tokenizer)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (None if request.tools is None else
@@ -87,12 +88,10 @@ class OpenAIServingTokenization(OpenAIServing):
                    add_special_tokens=request.add_special_tokens,
                )
            else:
-               (request_prompts,
-                engine_prompts) = await self._preprocess_completion(
-                   request,
-                   tokenizer,
-                   request.prompt,
+               engine_prompts = await renderer.render_prompt(
+                   prompt_or_prompts=request.prompt,
                    add_special_tokens=request.add_special_tokens,
+                   cache_salt=getattr(request, 'cache_salt', None),
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
@@ -101,7 +100,7 @@ class OpenAIServingTokenization(OpenAIServing):
        input_ids: list[int] = []
        for i, engine_prompt in enumerate(engine_prompts):
            self._log_inputs(request_id,
-                            request_prompts[i],
+                            engine_prompt,
                             params=None,
                             lora_request=lora_request)

vllm/entrypoints/renderer.py (new file, 219 lines added)
@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from abc import ABC, abstractmethod
from typing import Annotated, Optional, Union

from pydantic import Field

from vllm.config import ModelConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component.
    It converts high-level API requests (OpenAI-style JSON) into token IDs and
    multimodal features ready for engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between API layer and engine core
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer] = None,
    ):
        super().__init__()
        self.model_config = model_config
        self.tokenizer = tokenizer

    @abstractmethod
    async def render_prompt(
        self,
        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
        max_length: Optional[int] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: Optional[bool] = True,
        cache_salt: Optional[str] = None,
    ) -> list[EngineTokensPrompt]:
        """
        Convert input prompts into tokenized format for engine processing.

        This is the core method that transforms various input formats into
        standardized TokensPrompt objects. Implementations should handle
        tokenization, special token insertion, truncation, and validation
        according to model requirements.

        Args:
            prompt_or_prompts: Input data in various formats:
                - str: Single text prompt
                - list[str]: Batch of text prompts
                - list[int]: Pre-tokenized sequence
                - list[list[int]]: Batch of pre-tokenized sequences
            max_length: Maximum sequence length (endpoint-specific behavior)
            truncate_prompt_tokens: Truncate to last N tokens
                (None=no truncation, 0=empty)
            add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
                to text inputs
            cache_salt: Optional string to disambiguate cached prompts

        Returns:
            list[EngineTokensPrompt]: Tokenized prompts ready for engine
                consumption

        Raises:
            ValueError: If input format is invalid or length limits exceeded
        """
        raise NotImplementedError


class CompletionRenderer(BaseRenderer):

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer] = None,
        async_tokenizer_pool: Optional[dict[AnyTokenizer,
                                            AsyncMicrobatchTokenizer]] = None,
    ):
        super().__init__(model_config, tokenizer)
        self.async_tokenizer_pool = async_tokenizer_pool or {}
        self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None

    async def render_prompt(
        self,
        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
        max_length: Optional[int] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: Optional[bool] = True,
        cache_salt: Optional[str] = None,
    ) -> list[EngineTokensPrompt]:
        """Implementation of prompt rendering for completion-style requests.

        Uses async tokenizer pooling for improved performance. See base class
        for detailed parameter documentation.
        """
        if truncate_prompt_tokens is not None:
            if max_length is not None:
                assert 0 <= truncate_prompt_tokens <= max_length
            if truncate_prompt_tokens == 0:
                return []

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)

        rendered_prompts: list[EngineTokensPrompt] = []
        tokenize_tasks = []
        for prompt_input in batch_inputs:
            if prompt_input["is_tokens"] is True:
                # Token input
                token_ids = self._maybe_apply_truncation(
                    prompt_input["content"], truncate_prompt_tokens)
                rendered_prompts.append(
                    self._create_tokens_prompt(token_ids, max_length,
                                               cache_salt))
            else:
                # Text input
                tokenize_task = asyncio.create_task(
                    self._tokenize(prompt_input["content"], max_length,
                                   truncate_prompt_tokens, add_special_tokens,
                                   cache_salt))
                tokenize_tasks.append(tokenize_task)

        # Wait for all text tokenization to finish
        if tokenize_tasks:
            tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
            rendered_prompts.extend(tokenized_text_prompts)

        return rendered_prompts

    def _maybe_apply_truncation(
            self, token_ids: list[int],
            truncate_prompt_tokens: Optional[int]) -> list[int]:
        """Apply truncation to token sequence."""
        if truncate_prompt_tokens is None:
            return token_ids
        if truncate_prompt_tokens >= len(token_ids):
            return token_ids

        return token_ids[-truncate_prompt_tokens:]

    async def _tokenize(
        self,
        text: str,
        max_length: Optional[int],
        truncate_prompt_tokens: Optional[int],
        add_special_tokens: Optional[bool],
        cache_salt: Optional[str],
    ) -> EngineTokensPrompt:
        """Tokenize text input asynchronously."""
        async_tokenizer = self._get_async_tokenizer()

        # Handle encoder-specific preprocessing
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            text = text.lower()

        # Tokenize texts
        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(
                text, add_special_tokens=add_special_tokens)
        else:
            encoded = await async_tokenizer(
                text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=truncate_prompt_tokens)

        return self._create_tokens_prompt(encoded.input_ids, max_length,
                                          cache_salt)

    def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
        """Get or create async tokenizer using shared pool."""
        if self.async_tokenizer is not None:
            return self.async_tokenizer
        if self.tokenizer is None:
            raise ValueError(
                "No tokenizer available for text input processing")

        # Check shared pool first
        if self.tokenizer in self.async_tokenizer_pool:
            return self.async_tokenizer_pool[self.tokenizer]

        # Create new async tokenizer and add to pool
        self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer)
        self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer
        return self.async_tokenizer

    def _create_tokens_prompt(
        self,
        token_ids: list[int],
        max_length: Optional[int] = None,
        cache_salt: Optional[str] = None,
    ) -> EngineTokensPrompt:
        """Create validated EngineTokensPrompt."""
        if max_length is not None and len(token_ids) > max_length:
            raise ValueError(
                f"This maximum context length is {max_length} tokens. "
                f"However, your request has {len(token_ids)} input tokens. "
                "Please reduce the length of the input messages.")

        tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
        if cache_salt is not None:
            tokens_prompt["cache_salt"] = cache_salt
        return tokens_prompt
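The truncation and length-validation rules documented in `render_prompt` can be exercised directly: `truncate_prompt_tokens` keeps the last N tokens, zero short-circuits to an empty result, and anything longer than `max_length` raises. A short sketch, once more with the hypothetical `_Config` stand-in used in the tests:

import asyncio
from vllm.entrypoints.renderer import CompletionRenderer

class _Config:  # hypothetical stand-in for ModelConfig
    max_model_len = 100
    encoder_config = None

async def main():
    renderer = CompletionRenderer(model_config=_Config(), tokenizer=None,
                                  async_tokenizer_pool={})

    # truncate_prompt_tokens keeps the *last* N tokens of a pre-tokenized prompt.
    out = await renderer.render_prompt(prompt_or_prompts=list(range(10)),
                                       max_length=100,
                                       truncate_prompt_tokens=5)
    assert out[0]["prompt_token_ids"] == [5, 6, 7, 8, 9]

    # truncate_prompt_tokens=0 short-circuits to an empty result.
    empty = await renderer.render_prompt(prompt_or_prompts=[1, 2, 3],
                                         max_length=100,
                                         truncate_prompt_tokens=0)
    assert empty == []

    # Prompts longer than max_length raise rather than being silently cut.
    try:
        await renderer.render_prompt(prompt_or_prompts=list(range(150)),
                                     max_length=100)
    except ValueError as e:
        print(e)

asyncio.run(main())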