[Renderer] Clean up renderer code (#26216)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-05 01:05:29 +08:00 committed by GitHub
parent a42d2df75f
commit 119f00630b
5 changed files with 93 additions and 133 deletions

View File

@@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
         prompt=token_ids,
         max_tokens=20,
         temperature=0,
-        echo=True,
+        echo=False,
         extra_body={
             "return_token_ids": True,
         },

View File

@@ -4,7 +4,7 @@
 import pytest
 
 from vllm.inputs import zip_enc_dec_prompts
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import parse_raw_prompts
 
 pytestmark = pytest.mark.cpu_test
@@ -31,30 +31,30 @@ INPUTS_SLICES = [
 ]
 
 
-def test_parse_single_batch_empty():
+def test_parse_raw_single_batch_empty():
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([])
+        parse_raw_prompts([])
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([[]])
+        parse_raw_prompts([[]])
 
 
 @pytest.mark.parametrize('string_input', STRING_INPUTS)
-def test_parse_single_batch_string_consistent(string_input: str):
-    assert parse_and_batch_prompt(string_input) \
-        == parse_and_batch_prompt([string_input])
+def test_parse_raw_single_batch_string_consistent(string_input: str):
+    assert parse_raw_prompts(string_input) \
+        == parse_raw_prompts([string_input])
 
 
 @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
-def test_parse_single_batch_token_consistent(token_input: list[int]):
-    assert parse_and_batch_prompt(token_input) \
-        == parse_and_batch_prompt([token_input])
+def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
-    assert parse_raw_prompts(token_input) \
+    assert parse_raw_prompts(token_input) \
+        == parse_raw_prompts([token_input])
 
 
 @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
-def test_parse_single_batch_string_slice(inputs_slice: slice):
-    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
-        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
+    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
+        == parse_raw_prompts(STRING_INPUTS[inputs_slice])
 
 
 # yapf: disable
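
For reference (not part of the diff): the renamed helper now returns engine prompt dicts directly, so the single-vs-batch consistency these tests check looks roughly like the sketch below, assuming TextPrompt and TokensPrompt are the TypedDicts exported from vllm.inputs.data.

    # Illustrative sketch only; see the vllm.inputs.parse hunk further down.
    from vllm.inputs.data import TextPrompt, TokensPrompt
    from vllm.inputs.parse import parse_raw_prompts

    # A single string and a one-element list of strings normalize identically.
    assert parse_raw_prompts("Hello") == [TextPrompt(prompt="Hello")]
    assert parse_raw_prompts(["Hello"]) == [TextPrompt(prompt="Hello")]

    # Token inputs become TokensPrompt entries, one per prompt.
    assert parse_raw_prompts([1, 2, 3]) == [TokensPrompt(prompt_token_ids=[1, 2, 3])]
    assert parse_raw_prompts([[1, 2], [3]]) == [
        TokensPrompt(prompt_token_ids=[1, 2]),
        TokensPrompt(prompt_token_ids=[3]),
    ]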

View File

@@ -691,6 +691,5 @@ class OpenAIServingCompletion(OpenAIServing):
             truncate_prompt_tokens=request.truncate_prompt_tokens,
             add_special_tokens=request.add_special_tokens,
             cache_salt=request.cache_salt,
-            needs_detokenization=bool(request.echo
-                                      and not request.return_token_ids),
+            needs_detokenization=bool(request.echo),
         )
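
In plain terms (illustration only, request values hypothetical): echo alone now decides whether token-ID prompts are detokenized for the response; return_token_ids no longer suppresses it.

    # Before: echo was ignored whenever return_token_ids was set.
    echo, return_token_ids = True, True
    old_value = bool(echo and not return_token_ids)  # False
    new_value = bool(echo)                           # True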

View File

@@ -13,8 +13,9 @@ from pydantic import Field
 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer
@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""
 
+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+
 
 class BaseRenderer(ABC):
     """
@@ -74,7 +96,7 @@ class BaseRenderer(ABC):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@ class BaseRenderer(ABC):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,47 +211,25 @@ class CompletionRenderer(BaseRenderer):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
-
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-
-        return []
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+
+        return await asyncio.gather(*tasks)
 
     async def render_prompt_and_embeds(
         self,
@@ -237,14 +237,14 @@ class CompletionRenderer(BaseRenderer):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
@@ -265,29 +265,6 @@ class CompletionRenderer(BaseRenderer):
 
         return rendered
 
-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-
-        if truncate_prompt_tokens == 0:
-            return 0
-
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-
-        return truncate_prompt_tokens
-
     def _maybe_apply_truncation(
             self, token_ids: list[int],
             truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@ class CompletionRenderer(BaseRenderer):
 
         return token_ids[-truncate_prompt_tokens:]
 
-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
@@ -330,7 +338,7 @@ class CompletionRenderer(BaseRenderer):
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)
 
-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
                                                      truncate_prompt_tokens)
 
         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)
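
Putting the pieces together, a hypothetical caller in the serving layer would drive the refactored renderer roughly as follows; the names come from the hunks above, but the exact call site may differ.

    # Sketch only: build a RenderConfig from the request and render prompts.
    config = RenderConfig(
        max_length=max_model_len,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        add_special_tokens=request.add_special_tokens,
        cache_salt=request.cache_salt,
        needs_detokenization=bool(request.echo),
    )
    engine_prompts = await renderer.render_prompt(
        prompt_or_prompts=request.prompt,
        config=config,
    )
    # Each element is an EngineTokensPrompt, produced concurrently via
    # parse_raw_prompts() -> _create_prompt() -> asyncio.gather().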

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Sequence
 from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
-                    Union, cast, overload)
+                    Union, cast)
 
 from typing_extensions import TypeIs
@@ -16,34 +16,12 @@ if TYPE_CHECKING:
     import torch
 
 
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-
-
-class ParsedTokens(TypedDict):
-    content: list[int]
-    is_tokens: Literal[True]
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
-    ...
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
-    ...
-
-
-def parse_and_batch_prompt(
+def parse_raw_prompts(
     prompt: Union[str, list[str], list[int], list[list[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
     if isinstance(prompt, str):
         # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
+        return [TextPrompt(prompt=prompt)]
 
     if isinstance(prompt, list):
         if len(prompt) == 0:
@@ -52,13 +30,11 @@ def parse_raw_prompts(
         if is_list_of(prompt, str):
             # case 2: array of strings
             prompt = cast(list[str], prompt)
-            return [
-                ParsedText(content=elem, is_tokens=False) for elem in prompt
-            ]
+            return [TextPrompt(prompt=elem) for elem in prompt]
 
         if is_list_of(prompt, int):
             # case 3: array of tokens
             prompt = cast(list[int], prompt)
-            return [ParsedTokens(content=prompt, is_tokens=True)]
+            return [TokensPrompt(prompt_token_ids=prompt)]
 
         if is_list_of(prompt, list):
             prompt = cast(list[list[int]], prompt)
             if len(prompt[0]) == 0:
@@ -66,10 +42,7 @@ def parse_raw_prompts(
             if is_list_of(prompt[0], int):
                 # case 4: array of token arrays
-                return [
-                    ParsedTokens(content=elem, is_tokens=True)
-                    for elem in prompt
-                ]
+                return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
 
     raise TypeError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
@@ -99,26 +72,6 @@ ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                               ParsedTokensPrompt, ParsedEmbedsPrompt]
 
 
-@overload
-def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
-    ...
-
-
 def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)