Consolidate rendering parameters into RenderConfig dataclass (#24543)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Author: Flora Feng
Date: 2025-09-10 01:44:47 -07:00
Committed by: GitHub (GPG key ID: B5690EEEBB952194)
Parent: feaf202e93
Commit: 77f62613f9

8 changed files with 167 additions and 108 deletions
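The change replaces the renderer's loose keyword arguments (max_length, truncate_prompt_tokens, add_special_tokens, cache_salt, needs_detokenization) with a single RenderConfig object built per endpoint. A minimal sketch of the before/after calling convention, assuming vLLM at this commit and an already constructed renderer (setup elided):

    from vllm.entrypoints.renderer import RenderConfig

    # Before: each rendering knob was passed as its own keyword argument.
    # results = await renderer.render_prompt(prompt_or_prompts=[101, 7592, 2088],
    #                                        max_length=100,
    #                                        truncate_prompt_tokens=5)

    # After: the same knobs travel together in one immutable config object.
    config = RenderConfig(max_length=100, truncate_prompt_tokens=5)
    # results = await renderer.render_prompt(prompt_or_prompts=[101, 7592, 2088],
    #                                        config=config)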

View File

@@ -10,7 +10,7 @@ import pybase64
 import pytest
 import torch

-from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
 from vllm.inputs.data import is_embeds_prompt
@@ -56,8 +56,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_input(self, renderer):
         tokens = [101, 7592, 2088]
-        results = await renderer.render_prompt(prompt_or_prompts=tokens,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens, config=RenderConfig(max_length=100))

         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == tokens
@@ -65,8 +65,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_list_input(self, renderer):
         token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
-        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))

         assert len(results) == 3
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -80,8 +80,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer

-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))

         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -96,7 +97,8 @@ class TestRenderPrompt:
         text_list_input = ["Hello world", "How are you?", "Good morning"]

         results = await renderer.render_prompt(
-            prompt_or_prompts=text_list_input, max_length=100)
+            prompt_or_prompts=text_list_input,
+            config=RenderConfig(max_length=100))

         assert len(results) == 3
         for result in results:
@@ -110,8 +112,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer

-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -126,8 +129,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer

         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100,
-                                               truncate_prompt_tokens=50)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=50))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -143,8 +147,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer

         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=200,
-                                               truncate_prompt_tokens=-1)
+                                               config=RenderConfig(
+                                                   max_length=200,
+                                                   truncate_prompt_tokens=-1))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -157,8 +162,9 @@ class TestRenderPrompt:
         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
                        109]  # 10 tokens
         results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                               max_length=100,
-                                               truncate_prompt_tokens=5)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=5))

         assert len(results) == 1
         # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@@ -170,7 +176,7 @@ class TestRenderPrompt:
         with pytest.raises(ValueError, match="maximum context length"):
             await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                         max_length=100)
+                                         config=RenderConfig(max_length=100))

     @pytest.mark.asyncio
     async def test_no_tokenizer_for_text(self, mock_model_config):
@@ -181,7 +187,8 @@ class TestRenderPrompt:
         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
-                prompt_or_prompts="Hello world", max_length=100)
+                prompt_or_prompts="Hello world",
+                config=RenderConfig(max_length=100))

     @pytest.mark.asyncio
     async def test_token_input_with_needs_detokenization(
@@ -196,7 +203,7 @@ class TestRenderPrompt:
         tokens = [1, 2, 3, 4]
         results = await renderer.render_prompt(
             prompt_or_prompts=tokens,
-            needs_detokenization=True,
+            config=RenderConfig(needs_detokenization=True),
         )

         assert len(results) == 1
@@ -221,7 +228,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, cache_salt="test_salt")
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(cache_salt="test_salt"),
+        )

         assert len(results) == 1
         assert is_embeds_prompt(results[0])
@@ -240,7 +249,9 @@ class TestRenderEmbedPrompt:
         ]

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes_list)
+            prompt_embeds=embed_bytes_list,
+            config=RenderConfig(),
+        )

         assert len(results) == 2
         for i, result in enumerate(results):
@@ -254,7 +265,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(truncate_prompt_tokens=10),
+        )

         assert len(results) == 1
         # Should keep last 10 tokens
@@ -271,7 +284,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 1
         assert results[0]["prompt_embeds"].dtype == dtype
@@ -283,7 +298,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 1
         # Should be squeezed to 2D
@@ -303,7 +320,10 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+            prompt_or_prompts="Hello world",
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 2
         # First should be embed prompt

View File

@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
                                                     OpenAIServing,
                                                     ServeContext)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput, PoolingRequestOutput
 from vllm.pooling_params import PoolingParams
@@ -57,8 +58,7 @@ class ClassificationMixin(OpenAIServing):
             renderer = self._get_renderer(ctx.tokenizer)

             ctx.engine_prompts = await renderer.render_prompt(
                 prompt_or_prompts=ctx.request.input,
-                max_length=self.max_model_len,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
+                config=self._build_render_config(ctx.request))

             return None
@@ -114,6 +114,12 @@ class ClassificationMixin(OpenAIServing):
             usage=usage,
         )

+    def _build_render_config(self,
+                             request: ClassificationRequest) -> RenderConfig:
+        return RenderConfig(
+            max_length=self.max_model_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens)
+

 class ServingClassification(ClassificationMixin):
     request_id_prefix = "classify"

View File

@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                     clamp_prompt_logprobs)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
@@ -129,18 +130,11 @@ class OpenAIServingCompletion(OpenAIServing):
             tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                )
             renderer = self._get_renderer(tokenizer)
-            max_input_tokens_len = self.max_model_len - (request.max_tokens
-                                                         or 0)

             engine_prompts = await renderer.render_prompt_and_embeds(
                 prompt_or_prompts=request.prompt,
                 prompt_embeds=request.prompt_embeds,
-                max_length=max_input_tokens_len,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-                cache_salt=request.cache_salt,
-                needs_detokenization=bool(request.echo
-                                          and not request.return_token_ids),
+                config=self._build_render_config(request),
             )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -677,3 +671,18 @@
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
+
+    def _build_render_config(
+        self,
+        request: CompletionRequest,
+        max_input_length: Optional[int] = None,
+    ) -> RenderConfig:
+        max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
+        return RenderConfig(
+            max_length=max_input_tokens_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
+            cache_salt=request.cache_salt,
+            needs_detokenization=bool(request.echo
+                                      and not request.return_token_ids),
+        )
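The completion path derives the prompt budget by subtracting the requested completion length from the model's context window, as in max_input_tokens_len above. A worked example with illustrative numbers (not taken from the diff):

    # Illustrative values only; mirrors max_input_tokens_len above.
    max_model_len = 8192
    max_tokens = 512                       # request.max_tokens; may be None
    max_length = max_model_len - (max_tokens or 0)
    assert max_length == 7680              # room left for the prompt itself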

View File

@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                     TextTokensPrompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@@ -97,23 +98,28 @@ class EmbeddingMixin(OpenAIServing):
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                # Set max_length based on chunked processing capability
-                if self._should_use_chunked_processing(ctx.request):
-                    max_length = None
-                else:
-                    max_length = self.max_embed_len or self.max_model_len
-
                 ctx.engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=ctx.request.input,
-                    max_length=max_length,
-                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
-                    add_special_tokens=ctx.request.add_special_tokens,
+                    config=self._build_render_config(ctx.request),
                 )

             return None
         except (ValueError, TypeError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    def _build_render_config(
+            self, request: EmbeddingCompletionRequest) -> RenderConfig:
+        # Set max_length based on chunked processing capability
+        if self._should_use_chunked_processing(request):
+            max_length = None
+        else:
+            max_length = self.max_embed_len or self.max_model_len
+        return RenderConfig(
+            max_length=max_length,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens)
+
     @override
     def _build_response(
         self,
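The embedding path chooses the cap with a fallback chain: no cap at all when chunked processing handles long inputs, otherwise max_embed_len if configured, else max_model_len. A small sketch with illustrative values (not taken from the diff):

    # Illustrative values only; mirrors _build_render_config above.
    should_chunk = False   # stand-in for self._should_use_chunked_processing(request)
    max_embed_len = None   # no separate embedding limit configured
    max_model_len = 512

    max_length = None if should_chunk else (max_embed_len or max_model_len)
    assert max_length == 512  # falls back to the model context length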

View File

@@ -58,7 +58,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TranslationRequest)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
-from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
+from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
+                                       RenderConfig)
 # yapf: enable
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@@ -248,6 +249,19 @@ class OpenAIServing:
             tokenizer=tokenizer,
             async_tokenizer_pool=self._async_tokenizer_pool)

+    def _build_render_config(
+        self,
+        request: Any,
+    ) -> RenderConfig:
+        """
+        Build and return a `RenderConfig` for an endpoint.
+
+        Used by the renderer to control how prompts are prepared
+        (e.g., tokenization and length handling). Endpoints should
+        implement this with logic appropriate to their request type.
+        """
+        raise NotImplementedError
+
     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
         Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
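Each endpoint now answers one question, "which RenderConfig does this request need?", by overriding the hook above; the concrete overrides in this diff (classification, completion, embedding, pooling, tokenization) all follow the same shape. A hypothetical endpoint is sketched below purely for illustration:

    from typing import Any

    from vllm.entrypoints.openai.serving_engine import OpenAIServing
    from vllm.entrypoints.renderer import RenderConfig


    class MyServingEndpoint(OpenAIServing):  # hypothetical endpoint, not part of this diff
        def _build_render_config(self, request: Any) -> RenderConfig:
            # Map request fields onto the renderer's knobs.
            return RenderConfig(
                max_length=self.max_model_len,
                truncate_prompt_tokens=getattr(request, "truncate_prompt_tokens",
                                               None))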

View File

@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
@@ -149,10 +150,7 @@ class OpenAIServingPooling(OpenAIServing):
             elif isinstance(request, PoolingCompletionRequest):
                 engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=request.input,
-                    max_length=self.max_model_len,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                    cache_salt=getattr(request, 'cache_salt', None),
+                    config=self._build_render_config(request),
                 )
             else:
                 raise ValueError(
@@ -270,3 +268,10 @@
             data=items,
             usage=usage,
         )
+
+    def _build_render_config(
+            self, request: PoolingCompletionRequest) -> RenderConfig:
+        return RenderConfig(
+            max_length=self.max_model_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens)

View File

@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -72,7 +73,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     [tool.model_dump() for tool in request.tools])
                 (
                     _,
-                    request_prompts,
+                    _,
                     engine_prompts,
                 ) = await self._preprocess_chat(
                     request,
@@ -90,15 +91,14 @@ class OpenAIServingTokenization(OpenAIServing):
             else:
                 engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=request.prompt,
-                    add_special_tokens=request.add_special_tokens,
-                    cache_salt=getattr(request, 'cache_salt', None),
+                    config=self._build_render_config(request),
                 )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(f"{e} {e.__cause__}")

         input_ids: list[int] = []
-        for i, engine_prompt in enumerate(engine_prompts):
+        for engine_prompt in engine_prompts:
             self._log_inputs(request_id,
                              engine_prompt,
                              params=None,
@@ -157,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing):
             return self.create_error_response(
                 f"Failed to get tokenizer info: {str(e)}")

+    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
+        return RenderConfig(add_special_tokens=request.add_special_tokens)
+

 @dataclass
 class TokenizerInfo:

View File

@@ -4,6 +4,7 @@
 import asyncio
 import io
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Annotated, Optional, Union

 import pybase64
@@ -18,6 +19,29 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer


+@dataclass(frozen=True)
+class RenderConfig:
+    """Configuration to control how prompts are prepared."""
+
+    max_length: Optional[int] = None
+    """Maximum allowable total input token length. If provided,
+    token inputs longer than this raise ``ValueError``."""
+
+    truncate_prompt_tokens: Optional[int] = None
+    """Number of tokens to keep. ``None`` means no truncation.
+    ``0`` yields an empty list (and skips embeds).
+    ``-1`` maps to ``model_config.max_model_len``."""
+
+    add_special_tokens: Optional[bool] = True
+    """Whether to add model-specific special tokens during tokenization."""
+
+    cache_salt: Optional[str] = None
+    """String to disambiguate prefix cache entries."""
+
+    needs_detokenization: Optional[bool] = False
+    """If True, detokenize IDs back to text for inclusion in outputs."""
+
+
 class BaseRenderer(ABC):
     """
     Base class for unified input processing and rendering.
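For reference, a quick sketch of how the dataclass above behaves (assuming vLLM at this commit): unspecified fields keep their defaults, and frozen=True makes instances read-only, so a config cannot be mutated after an endpoint builds it.

    import dataclasses

    from vllm.entrypoints.renderer import RenderConfig

    config = RenderConfig(max_length=100, truncate_prompt_tokens=5)
    print(config.add_special_tokens)   # True (default)
    print(config.cache_salt)           # None (default)

    try:
        config.max_length = 200        # frozen dataclass: assignment is rejected
    except dataclasses.FrozenInstanceError:
        print("RenderConfig is immutable")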
@@ -48,12 +72,9 @@ class BaseRenderer(ABC):
     @abstractmethod
     async def render_prompt(
         self,
+        *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -68,16 +89,8 @@ class BaseRenderer(ABC):
                 - ``list[str]``: Batch of text prompts.
                 - ``list[int]``: Single pre-tokenized sequence.
                 - ``list[list[int]]``: Batch of pre-tokenized sequences.
-            max_length: Maximum allowable total input token length. If provided,
-                token inputs longer than this raise ``ValueError``.
-            truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
-                truncation. ``0`` yields an empty list (and skips embeds).
-                ``-1`` maps to ``model_config.max_model_len``.
-            add_special_tokens: Whether to add model-specific special tokens
-                during text tokenization.
-            cache_salt: Optional string to disambiguate prefix cache entries.
-            needs_detokenization: If True and ``prompt_or_prompts`` is token
-                input, detokenize IDs back to text for inclusion in outputs.
+            config: Render configuration controlling how prompts are prepared
+                (e.g., tokenization and length handling).

         Returns:
             list[EngineTokensPrompt]: Engine-ready token prompts.
@@ -90,18 +103,15 @@ class BaseRenderer(ABC):
     @abstractmethod
     async def render_prompt_and_embeds(
         self,
+        *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
-        engine-ready prompt objects.
+        engine-ready prompt objects using a unified RenderConfig.

         At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
         provided and non-empty. If both are omitted or empty (e.g., empty
@@ -111,15 +121,8 @@ class BaseRenderer(ABC):
             prompt_or_prompts: Text or token inputs to include.
             prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                 torch-saved tensor to be used as prompt embeddings.
-            max_length: Maximum allowable total input token length. If provided,
-                inputs longer than this raise ``ValueError``.
-            truncate_prompt_tokens: Number of tokens/rows to keep from the end
-                of the sequence. ``-1`` maps to ``model_config.max_model_len``.
-            add_special_tokens: Whether to add model-specific special tokens
-                during text tokenization.
-            cache_salt: Optional string to disambiguate prefix cache entries.
-            needs_detokenization: If True and ``prompt_or_prompts`` is token
-                input, detokenize IDs back to text for inclusion in outputs.
+            config: Render configuration controlling how prompts are prepared
+                (e.g., tokenization and length handling).

         Returns:
             list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -184,12 +187,9 @@ class CompletionRenderer(BaseRenderer):
     async def render_prompt(
         self,
+        *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
@@ -197,7 +197,7 @@ class CompletionRenderer(BaseRenderer):
         for detailed parameter documentation.
         """
         truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            truncate_prompt_tokens, max_length)
+            config.truncate_prompt_tokens, config.max_length)
         if truncate_prompt_tokens == 0:
             return []
@@ -211,16 +211,19 @@ class CompletionRenderer(BaseRenderer):
                 detokenize_task = asyncio.create_task(
                     # Note: detokenization is needed when echo is enabled,
                     # where the input token IDs are decoded back to text.
-                    self._maybe_detokenize(prompt_input["content"], max_length,
-                                           truncate_prompt_tokens, cache_salt,
-                                           needs_detokenization))
+                    self._maybe_detokenize(prompt_input["content"],
+                                           config.max_length,
+                                           truncate_prompt_tokens,
+                                           config.cache_salt,
+                                           config.needs_detokenization))
                 tasks.append(detokenize_task)
             else:
                 # Text input
                 tokenize_task = asyncio.create_task(
-                    self._tokenize(prompt_input["content"], max_length,
-                                   truncate_prompt_tokens, add_special_tokens,
-                                   cache_salt))
+                    self._tokenize(prompt_input["content"], config.max_length,
+                                   truncate_prompt_tokens,
+                                   config.add_special_tokens,
+                                   config.cache_salt))
                 tasks.append(tokenize_task)

         # Wait for all text tokenization to finish
@@ -232,21 +235,18 @@ class CompletionRenderer(BaseRenderer):
     async def render_prompt_and_embeds(
         self,
+        *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
         truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            truncate_prompt_tokens, max_length)
+            config.truncate_prompt_tokens, config.max_length)
         if truncate_prompt_tokens == 0:
             return []
@@ -255,17 +255,13 @@ class CompletionRenderer(BaseRenderer):
         if prompt_embeds is not None:
             rendered.extend(
                 self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
-                                        cache_salt))
+                                        config.cache_salt))

         if prompt_or_prompts is None or prompt_or_prompts == "":
             return rendered

         token_prompts = await self.render_prompt(
             prompt_or_prompts=prompt_or_prompts,
-            max_length=max_length,
-            truncate_prompt_tokens=truncate_prompt_tokens,
-            add_special_tokens=add_special_tokens,
-            cache_salt=cache_salt,
-            needs_detokenization=needs_detokenization,
+            config=config,
         )
         rendered.extend(token_prompts)