[Refactor] Introduce basic Renderer for completion-style request (#24010)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng 2025-09-03 22:21:12 -07:00 committed by GitHub
parent e919d6f549
commit 712b273f65
5 changed files with 416 additions and 27 deletions
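
For orientation, a rough usage sketch of the renderer introduced here. The model_config and tokenizer objects are assumed to already be in scope, the call runs inside an async handler, and the serving layer would normally supply the shared pool via _get_renderer (shown further down) rather than an empty dict:

from vllm.entrypoints.renderer import CompletionRenderer

renderer = CompletionRenderer(model_config=model_config,
                              tokenizer=tokenizer,
                              async_tokenizer_pool={})

# Text prompts are tokenized asynchronously; pre-tokenized prompts pass
# through (optionally truncated). Both come back as TokensPrompt dicts
# ready for engine consumption.
engine_prompts = await renderer.render_prompt(
    prompt_or_prompts=["Hello world", "How are you?"],
    max_length=model_config.max_model_len,
    truncate_prompt_tokens=None,
    add_special_tokens=True,
)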

View File

@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.entrypoints.renderer import CompletionRenderer


@dataclass
class MockModelConfig:
max_model_len: int = 100
encoder_config: Optional[dict] = None


class MockTokenizerResult:
def __init__(self, input_ids):
self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
return MockModelConfig()


@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
return tokenizer


@pytest.fixture
def mock_async_tokenizer():
async_tokenizer = AsyncMock()
return async_tokenizer


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
return CompletionRenderer(model_config=mock_model_config,
tokenizer=mock_tokenizer,
async_tokenizer_pool={})


class TestRenderPrompt:
"""Test Category A: Basic Functionality Tests"""
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(prompt_or_prompts=tokens,
max_length=100)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(prompt_or_prompts=token_lists,
max_length=100)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.assert_called_once()
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input, max_length=100)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.call_count == 3
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert "truncation" not in call_args.kwargs or call_args.kwargs[
"truncation"] is False
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]) # Truncated
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100,
truncate_prompt_tokens=50)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
109] # 10 tokens
results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100,
truncate_prompt_tokens=5)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
renderer_no_tokenizer = CompletionRenderer(
model_config=mock_model_config,
tokenizer=None,
async_tokenizer_pool={})
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", max_length=100)

View File

@@ -62,8 +62,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
# yapf: enable
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@@ -243,6 +245,16 @@ class OpenAIServing:
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool)
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
@@ -1098,7 +1110,7 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: RequestPrompt,
inputs: Union[RequestPrompt, PromptType],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@@ -1110,11 +1122,9 @@ class OpenAIServing:
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
elif "prompt_embeds" in inputs:
prompt_embeds = inputs.get("prompt_embeds")
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
self.request_logger.log_inputs(
request_id,

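
A note on the shared pool that _get_renderer wires up above: per-request renderers are cheap, while the AsyncMicrobatchTokenizer instances they create are cached in self._async_tokenizer_pool and reused across requests that resolve the same tokenizer. A hypothetical handler-side sketch (self is an OpenAIServing instance, tokenizer is already resolved):

renderer_a = self._get_renderer(tokenizer)
renderer_b = self._get_renderer(tokenizer)

# Both renderers consult the same pool, so the first text prompt rendered
# creates one AsyncMicrobatchTokenizer entry that later renderers for this
# tokenizer reuse instead of spawning their own.
prompts_a = await renderer_a.render_prompt(prompt_or_prompts="Hello world",
                                           max_length=self.max_model_len)
prompts_b = await renderer_b.render_prompt(prompt_or_prompts="How are you?",
                                           max_length=self.max_model_len)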
View File

@@ -4,7 +4,7 @@
import asyncio
import base64
import time
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast
import jinja2
@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
@@ -104,6 +104,7 @@ class OpenAIServingPooling(OpenAIServing):
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
@@ -126,14 +127,11 @@ class OpenAIServingPooling(OpenAIServing):
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
@@ -149,13 +147,13 @@
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
add_special_tokens=request.add_special_tokens,
)
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
max_length=self.max_model_len,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
@@ -177,7 +175,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request)

View File

@@ -65,6 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
tool_dicts = (None if request.tools is None else
@@ -87,13 +88,11 @@
add_special_tokens=request.add_special_tokens,
)
else:
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
@@ -101,7 +100,7 @@ class OpenAIServingTokenization(OpenAIServing):
input_ids: list[int] = []
for i, engine_prompt in enumerate(engine_prompts):
self._log_inputs(request_id,
request_prompts[i],
engine_prompt,
params=None,
lora_request=lora_request)

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC, abstractmethod
from typing import Annotated, Optional, Union
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


class BaseRenderer(ABC):
"""
Base class for unified input processing and rendering.
The Renderer serves as a unified input processor that consolidates
tokenization, chat template formatting, and multimodal input handling
into a single component.
It converts high-level API requests (OpenAI-style JSON) into token IDs and
multimodal features ready for engine consumption.
Key responsibilities:
- Convert text prompts to token sequences with proper special tokens
- Apply chat templates and format conversations
- Handle multimodal inputs (images, audio, etc.) when applicable
- Manage prompt truncation and length validation
- Provide clean separation between API layer and engine core
"""
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[AnyTokenizer] = None,
):
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
@abstractmethod
async def render_prompt(
self,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
) -> list[EngineTokensPrompt]:
"""
Convert input prompts into tokenized format for engine processing.
This is the core method that transforms various input formats into
standardized TokensPrompt objects. Implementations should handle
tokenization, special token insertion, truncation, and validation
according to model requirements.
Args:
prompt_or_prompts: Input data in various formats:
- str: Single text prompt
- list[str]: Batch of text prompts
- list[int]: Pre-tokenized sequence
- list[list[int]]: Batch of pre-tokenized sequences
max_length: Maximum sequence length (endpoint-specific behavior)
truncate_prompt_tokens: Truncate to last N tokens
(None=no truncation, 0=empty)
add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
to text inputs
cache_salt: Optional string to disambiguate cached prompts
Returns:
list[EngineTokensPrompt]: Tokenized prompts ready for engine
consumption
Raises:
ValueError: If input format is invalid or length limits exceeded
"""
raise NotImplementedError


class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[AnyTokenizer] = None,
async_tokenizer_pool: Optional[dict[AnyTokenizer,
AsyncMicrobatchTokenizer]] = None,
):
super().__init__(model_config, tokenizer)
self.async_tokenizer_pool = async_tokenizer_pool or {}
self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None
async def render_prompt(
self,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
) -> list[EngineTokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
for detailed parameter documentation.
"""
if truncate_prompt_tokens is not None:
if max_length is not None:
assert 0 <= truncate_prompt_tokens <= max_length
if truncate_prompt_tokens == 0:
return []
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
rendered_prompts: list[EngineTokensPrompt] = []
tokenize_tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is True:
# Token input
token_ids = self._maybe_apply_truncation(
prompt_input["content"], truncate_prompt_tokens)
rendered_prompts.append(
self._create_tokens_prompt(token_ids, max_length,
cache_salt))
else:
# Text input
tokenize_task = asyncio.create_task(
self._tokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, add_special_tokens,
cache_salt))
tokenize_tasks.append(tokenize_task)
# Wait for all text tokenization to finish
if tokenize_tasks:
tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
rendered_prompts.extend(tokenized_text_prompts)
return rendered_prompts
def _maybe_apply_truncation(
self, token_ids: list[int],
truncate_prompt_tokens: Optional[int]) -> list[int]:
"""Apply truncation to token sequence."""
if truncate_prompt_tokens is None:
return token_ids
if truncate_prompt_tokens >= len(token_ids):
return token_ids
return token_ids[-truncate_prompt_tokens:]
async def _tokenize(
self,
text: str,
max_length: Optional[int],
truncate_prompt_tokens: Optional[int],
add_special_tokens: Optional[bool],
cache_salt: Optional[str],
) -> EngineTokensPrompt:
"""Tokenize text input asynchronously."""
async_tokenizer = self._get_async_tokenizer()
# Handle encoder-specific preprocessing
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
text = text.lower()
# Tokenize texts
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
text, add_special_tokens=add_special_tokens)
else:
encoded = await async_tokenizer(
text,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
return self._create_tokens_prompt(encoded.input_ids, max_length,
cache_salt)
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
if self.async_tokenizer is not None:
return self.async_tokenizer
if self.tokenizer is None:
raise ValueError(
"No tokenizer available for text input processing")
# Check shared pool first
if self.tokenizer in self.async_tokenizer_pool:
return self.async_tokenizer_pool[self.tokenizer]
# Create new async tokenizer and add to pool
self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer)
self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer
return self.async_tokenizer
def _create_tokens_prompt(
self,
token_ids: list[int],
max_length: Optional[int] = None,
cache_salt: Optional[str] = None,
) -> EngineTokensPrompt:
"""Create validated EngineTokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
raise ValueError(
f"This maximum context length is {max_length} tokens. "
f"However, your request has {len(token_ids)} input tokens. "
"Please reduce the length of the input messages.")
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
return tokens_prompt
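
To make the helper semantics concrete, a small sketch that exercises the private helpers above with the MockModelConfig from the test file at the top of this commit (illustrative only; these helpers are internal, not a suggested public API):

renderer = CompletionRenderer(model_config=MockModelConfig(),
                              tokenizer=None,
                              async_tokenizer_pool={})

# Truncation keeps the last N tokens and is a no-op for shorter inputs.
assert renderer._maybe_apply_truncation([1, 2, 3, 4, 5], 3) == [3, 4, 5]
assert renderer._maybe_apply_truncation([1, 2, 3], 10) == [1, 2, 3]

# The validated prompt is a plain TokensPrompt dict; cache_salt is only
# attached when provided.
prompt = renderer._create_tokens_prompt([7, 8, 9], cache_salt="abc123")
assert prompt == {"prompt_token_ids": [7, 8, 9], "cache_salt": "abc123"}

# Exceeding the caller-provided max_length raises ValueError.
try:
    renderer._create_tokens_prompt([7, 8, 9], max_length=2)
except ValueError:
    pass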