[Renderer] Clean up renderer code (#26216)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-05 01:05:29 +08:00 committed by GitHub
parent a42d2df75f
commit 119f00630b
5 changed files with 93 additions and 133 deletions


@@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
         prompt=token_ids,
         max_tokens=20,
         temperature=0,
-        echo=True,
+        echo=False,
         extra_body={
             "return_token_ids": True,
         },
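A minimal sketch (not part of the diff) of the request shape this test now sends: a token-ID prompt with echo disabled and token IDs requested through the vLLM-specific return_token_ids extra-body field. It assumes an OpenAI-compatible async client pointed at a running vLLM server; the base URL, model name, and token IDs are placeholders.

import asyncio

from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def main() -> None:
    # echo=False: the prompt is not echoed back as text; the prompt's
    # token IDs are still returned via the return_token_ids flag.
    completion = await client.completions.create(
        model="placeholder-model",
        prompt=[1, 2, 3],
        max_tokens=20,
        temperature=0,
        echo=False,
        extra_body={"return_token_ids": True},
    )
    print(completion)


asyncio.run(main())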


@@ -4,7 +4,7 @@
 import pytest
 from vllm.inputs import zip_enc_dec_prompts
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import parse_raw_prompts
 pytestmark = pytest.mark.cpu_test
@@ -31,30 +31,30 @@ INPUTS_SLICES = [
 ]
-def test_parse_single_batch_empty():
+def test_parse_raw_single_batch_empty():
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([])
+        parse_raw_prompts([])
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([[]])
+        parse_raw_prompts([[]])
 @pytest.mark.parametrize('string_input', STRING_INPUTS)
-def test_parse_single_batch_string_consistent(string_input: str):
-    assert parse_and_batch_prompt(string_input) \
-        == parse_and_batch_prompt([string_input])
+def test_parse_raw_single_batch_string_consistent(string_input: str):
+    assert parse_raw_prompts(string_input) \
+        == parse_raw_prompts([string_input])
 @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
-def test_parse_single_batch_token_consistent(token_input: list[int]):
-    assert parse_and_batch_prompt(token_input) \
-        == parse_and_batch_prompt([token_input])
+def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
+    assert parse_raw_prompts(token_input) \
+        == parse_raw_prompts([token_input])
 @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
-def test_parse_single_batch_string_slice(inputs_slice: slice):
-    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
-        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
+    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
+        == parse_raw_prompts(STRING_INPUTS[inputs_slice])
 # yapf: disable
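A rough sketch (not from the test file) of the invariant these renamed tests exercise: a bare prompt and a one-element batch parse to the same result under parse_raw_prompts. The literal inputs are placeholders.

from vllm.inputs.parse import parse_raw_prompts

# A lone string parses like a one-element list of strings, and a lone
# token list like a one-element list of token lists.
assert parse_raw_prompts("Hello") == parse_raw_prompts(["Hello"])
assert parse_raw_prompts([1, 2, 3]) == parse_raw_prompts([[1, 2, 3]])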


@@ -691,6 +691,5 @@ class OpenAIServingCompletion(OpenAIServing):
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
-            needs_detokenization=bool(request.echo
-                                      and not request.return_token_ids),
+            needs_detokenization=bool(request.echo),
         )


@@ -13,8 +13,9 @@ from pydantic import Field
 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer
@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""
+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+        if truncate_prompt_tokens == 0:
+            return 0
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+        return truncate_prompt_tokens
 class BaseRenderer(ABC):
     """
@@ -74,7 +96,7 @@ class BaseRenderer(ABC):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,47 +211,25 @@ class CompletionRenderer(BaseRenderer):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-        return []
+        return await asyncio.gather(*tasks)
     async def render_prompt_and_embeds(
         self,
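A minimal standalone sketch (separate from the vLLM code) of the pattern the rewritten render_prompt now uses: build a generator of coroutines, one per raw prompt, and await them together. asyncio.gather returns results in input order, so the rendered prompts keep the order of the raw prompts. The worker coroutine here is invented for illustration.

import asyncio


async def _render_one(item: int) -> int:
    # Stand-in for self._create_prompt(...).
    return item * 2


async def main() -> None:
    tasks = (_render_one(item) for item in [1, 2, 3])
    # Results come back in the same order as the inputs: [2, 4, 6]
    print(await asyncio.gather(*tasks))


asyncio.run(main())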
@@ -237,14 +237,14 @@
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
@@ -265,29 +265,6 @@
         return rendered
-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-        if truncate_prompt_tokens == 0:
-            return 0
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-        return truncate_prompt_tokens
     def _maybe_apply_truncation(
             self, token_ids: list[int],
             truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@
         return token_ids[-truncate_prompt_tokens:]
-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
@@ -330,7 +338,7 @@
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)
-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@
                                                  truncate_prompt_tokens)
         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)
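A rough sketch of the dispatch the new _create_prompt performs through get_prompt_components, based only on the unpacking shown above (text, token IDs, and a third component ignored here). The concrete inputs and the expectation that absent fields come back as None are assumptions for illustration.

from vllm.inputs.data import TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components

# A token prompt should surface its prompt_token_ids ...
text, token_ids, _ = get_prompt_components(TokensPrompt(prompt_token_ids=[1, 2]))
# expected (assumed): text is None, token_ids == [1, 2]

# ... while a text prompt should surface its prompt string.
text, token_ids, _ = get_prompt_components(TextPrompt(prompt="Hello"))
# expected (assumed): text == "Hello", token_ids is None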


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
 from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
-                    Union, cast, overload)
+                    Union, cast)
 from typing_extensions import TypeIs
@@ -16,34 +16,12 @@ if TYPE_CHECKING:
     import torch
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-class ParsedTokens(TypedDict):
-    content: list[int]
-    is_tokens: Literal[True]
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
-    ...
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
-    ...
-def parse_and_batch_prompt(
+def parse_raw_prompts(
     prompt: Union[str, list[str], list[int], list[list[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
     if isinstance(prompt, str):
         # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
+        return [TextPrompt(prompt=prompt)]
     if isinstance(prompt, list):
         if len(prompt) == 0:
@@ -52,13 +30,11 @@ def parse_and_batch_prompt(
         if is_list_of(prompt, str):
             # case 2: array of strings
             prompt = cast(list[str], prompt)
-            return [
-                ParsedText(content=elem, is_tokens=False) for elem in prompt
-            ]
+            return [TextPrompt(prompt=elem) for elem in prompt]
         if is_list_of(prompt, int):
             # case 3: array of tokens
             prompt = cast(list[int], prompt)
-            return [ParsedTokens(content=prompt, is_tokens=True)]
+            return [TokensPrompt(prompt_token_ids=prompt)]
         if is_list_of(prompt, list):
             prompt = cast(list[list[int]], prompt)
             if len(prompt[0]) == 0:
@@ -66,10 +42,7 @@
             if is_list_of(prompt[0], int):
                 # case 4: array of token arrays
-                return [
-                    ParsedTokens(content=elem, is_tokens=True)
-                    for elem in prompt
-                ]
+                return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
     raise TypeError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
@@ -99,26 +72,6 @@ ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                               ParsedTokensPrompt, ParsedEmbedsPrompt]
-@overload
-def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
-    ...
-@overload
-def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
-    ...
-@overload
-def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
-    ...
-@overload
-def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
-    ...
 def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)