mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:34:58 +08:00
[Renderer] Clean up renderer code (#26216)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
a42d2df75f
commit
119f00630b
@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
|
||||
prompt=token_ids,
|
||||
max_tokens=20,
|
||||
temperature=0,
|
||||
echo=True,
|
||||
echo=False,
|
||||
extra_body={
|
||||
"return_token_ids": True,
|
||||
},
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.inputs import zip_enc_dec_prompts
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.inputs.parse import parse_raw_prompts
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -31,30 +31,30 @@ INPUTS_SLICES = [
|
||||
]
|
||||
|
||||
|
||||
def test_parse_single_batch_empty():
|
||||
def test_parse_raw_single_batch_empty():
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
parse_and_batch_prompt([])
|
||||
parse_raw_prompts([])
|
||||
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
parse_and_batch_prompt([[]])
|
||||
parse_raw_prompts([[]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('string_input', STRING_INPUTS)
|
||||
def test_parse_single_batch_string_consistent(string_input: str):
|
||||
assert parse_and_batch_prompt(string_input) \
|
||||
== parse_and_batch_prompt([string_input])
|
||||
def test_parse_raw_single_batch_string_consistent(string_input: str):
|
||||
assert parse_raw_prompts(string_input) \
|
||||
== parse_raw_prompts([string_input])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
|
||||
def test_parse_single_batch_token_consistent(token_input: list[int]):
|
||||
assert parse_and_batch_prompt(token_input) \
|
||||
== parse_and_batch_prompt([token_input])
|
||||
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
|
||||
assert parse_raw_prompts(token_input) \
|
||||
== parse_raw_prompts([token_input])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
|
||||
def test_parse_single_batch_string_slice(inputs_slice: slice):
|
||||
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
|
||||
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
|
||||
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
|
||||
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
|
||||
== parse_raw_prompts(STRING_INPUTS[inputs_slice])
|
||||
|
||||
|
||||
# yapf: disable
|
||||
|
||||
@ -691,6 +691,5 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo
|
||||
and not request.return_token_ids),
|
||||
needs_detokenization=bool(request.echo),
|
||||
)
|
||||
|
||||
@ -13,8 +13,9 @@ from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TextPrompt as EngineTextPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AsyncMicrobatchTokenizer
|
||||
|
||||
@ -41,6 +42,27 @@ class RenderConfig:
|
||||
needs_detokenization: Optional[bool] = False
|
||||
"""If True, detokenize IDs back to text for inclusion in outputs."""
|
||||
|
||||
def verify_truncate_prompt_tokens(
|
||||
self, model_config: ModelConfig) -> Optional[int]:
|
||||
"""Validate and normalize `truncate_prompt_tokens` parameter."""
|
||||
truncate_prompt_tokens = self.truncate_prompt_tokens
|
||||
if truncate_prompt_tokens is None:
|
||||
return None
|
||||
|
||||
if truncate_prompt_tokens == 0:
|
||||
return 0
|
||||
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = model_config.max_model_len
|
||||
|
||||
max_length = self.max_length
|
||||
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
|
||||
raise ValueError(
|
||||
f"{truncate_prompt_tokens=} cannot be greater than "
|
||||
f"{max_length=}. Please select a smaller truncation size.")
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
|
||||
class BaseRenderer(ABC):
|
||||
"""
|
||||
@ -74,7 +96,7 @@ class BaseRenderer(ABC):
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
|
||||
config: "RenderConfig",
|
||||
config: RenderConfig,
|
||||
) -> list[EngineTokensPrompt]:
|
||||
"""
|
||||
Convert text or token inputs into engine-ready TokensPrompt objects.
|
||||
@ -107,7 +129,7 @@ class BaseRenderer(ABC):
|
||||
prompt_or_prompts: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]] = None,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
config: "RenderConfig",
|
||||
config: RenderConfig,
|
||||
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
"""
|
||||
Convert text/token and/or base64-encoded embeddings inputs into
|
||||
@ -189,47 +211,25 @@ class CompletionRenderer(BaseRenderer):
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
|
||||
config: "RenderConfig",
|
||||
config: RenderConfig,
|
||||
) -> list[EngineTokensPrompt]:
|
||||
"""Implementation of prompt rendering for completion-style requests.
|
||||
|
||||
Uses async tokenizer pooling for improved performance. See base class
|
||||
for detailed parameter documentation.
|
||||
"""
|
||||
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
|
||||
config.truncate_prompt_tokens, config.max_length)
|
||||
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
|
||||
self.model_config)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
# Parse and batch the input prompts
|
||||
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
|
||||
tasks = (self._create_prompt(
|
||||
prompt_input,
|
||||
config=config,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
) for prompt_input in parse_raw_prompts(prompt_or_prompts))
|
||||
|
||||
tasks = []
|
||||
for prompt_input in batch_inputs:
|
||||
if prompt_input["is_tokens"] is True:
|
||||
# Token input
|
||||
# Note: detokenization is needed when echo is enabled,
|
||||
# where the input token IDs are decoded back to text.
|
||||
task = self._maybe_detokenize(prompt_input["content"],
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.cache_salt,
|
||||
config.needs_detokenization)
|
||||
else:
|
||||
# Text input
|
||||
task = self._tokenize(prompt_input["content"],
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.add_special_tokens,
|
||||
config.cache_salt)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all text tokenization to finish
|
||||
if tasks:
|
||||
tokenized_text_prompts = await asyncio.gather(*tasks)
|
||||
return tokenized_text_prompts
|
||||
|
||||
return []
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
@ -237,14 +237,14 @@ class CompletionRenderer(BaseRenderer):
|
||||
prompt_or_prompts: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]] = None,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
config: "RenderConfig",
|
||||
config: RenderConfig,
|
||||
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
"""
|
||||
Render text/token prompts and/or precomputed embedding prompts. At
|
||||
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
|
||||
"""
|
||||
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
|
||||
config.truncate_prompt_tokens, config.max_length)
|
||||
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
|
||||
self.model_config)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
@ -265,29 +265,6 @@ class CompletionRenderer(BaseRenderer):
|
||||
|
||||
return rendered
|
||||
|
||||
def _validate_and_normalize_truncate_tokens(
|
||||
self,
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
max_length: Optional[int],
|
||||
) -> Optional[int]:
|
||||
"""Validate and normalize truncate_prompt_tokens parameter."""
|
||||
if truncate_prompt_tokens is None:
|
||||
return None
|
||||
|
||||
if truncate_prompt_tokens == 0:
|
||||
return 0
|
||||
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = self.model_config.max_model_len
|
||||
|
||||
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||
f"cannot be greater than max_length ({max_length}). "
|
||||
f"Please select a smaller truncation size.")
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
def _maybe_apply_truncation(
|
||||
self, token_ids: list[int],
|
||||
truncate_prompt_tokens: Optional[int]) -> list[int]:
|
||||
@ -299,7 +276,38 @@ class CompletionRenderer(BaseRenderer):
|
||||
|
||||
return token_ids[-truncate_prompt_tokens:]
|
||||
|
||||
async def _tokenize(
|
||||
async def _create_prompt(
|
||||
self,
|
||||
prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
|
||||
config: RenderConfig,
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
) -> EngineTokensPrompt:
|
||||
prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
# NOTE: detokenization is needed when echo is enabled,
|
||||
# where the input token IDs are decoded back to text.
|
||||
return await self._create_prompt_from_token_ids(
|
||||
prompt_token_ids,
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.cache_salt,
|
||||
config.needs_detokenization,
|
||||
)
|
||||
|
||||
if prompt is not None:
|
||||
return await self._create_prompt_from_text(
|
||||
prompt,
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.add_special_tokens,
|
||||
config.cache_salt,
|
||||
)
|
||||
|
||||
# TODO: Also handle embeds prompt using this method
|
||||
raise NotImplementedError
|
||||
|
||||
async def _create_prompt_from_text(
|
||||
self,
|
||||
text: str,
|
||||
max_length: Optional[int],
|
||||
@ -330,7 +338,7 @@ class CompletionRenderer(BaseRenderer):
|
||||
return self._create_tokens_prompt(encoded.input_ids, max_length,
|
||||
cache_salt, text)
|
||||
|
||||
async def _maybe_detokenize(
|
||||
async def _create_prompt_from_token_ids(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
max_length: Optional[int],
|
||||
@ -343,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
|
||||
truncate_prompt_tokens)
|
||||
|
||||
prompt = None
|
||||
if needs_detokenization is True:
|
||||
if needs_detokenization:
|
||||
async_tokenizer = self._get_async_tokenizer()
|
||||
prompt = await async_tokenizer.decode(token_ids)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
|
||||
Union, cast, overload)
|
||||
Union, cast)
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
@ -16,34 +16,12 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
content: str
|
||||
is_tokens: Literal[False]
|
||||
|
||||
|
||||
class ParsedTokens(TypedDict):
|
||||
content: list[int]
|
||||
is_tokens: Literal[True]
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
|
||||
...
|
||||
|
||||
|
||||
def parse_and_batch_prompt(
|
||||
def parse_raw_prompts(
|
||||
prompt: Union[str, list[str], list[int], list[list[int]]],
|
||||
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
|
||||
) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [ParsedText(content=prompt, is_tokens=False)]
|
||||
return [TextPrompt(prompt=prompt)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
@ -52,13 +30,11 @@ def parse_and_batch_prompt(
|
||||
if is_list_of(prompt, str):
|
||||
# case 2: array of strings
|
||||
prompt = cast(list[str], prompt)
|
||||
return [
|
||||
ParsedText(content=elem, is_tokens=False) for elem in prompt
|
||||
]
|
||||
return [TextPrompt(prompt=elem) for elem in prompt]
|
||||
if is_list_of(prompt, int):
|
||||
# case 3: array of tokens
|
||||
prompt = cast(list[int], prompt)
|
||||
return [ParsedTokens(content=prompt, is_tokens=True)]
|
||||
return [TokensPrompt(prompt_token_ids=prompt)]
|
||||
if is_list_of(prompt, list):
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
if len(prompt[0]) == 0:
|
||||
@ -66,10 +42,7 @@ def parse_and_batch_prompt(
|
||||
|
||||
if is_list_of(prompt[0], int):
|
||||
# case 4: array of token arrays
|
||||
return [
|
||||
ParsedTokens(content=elem, is_tokens=True)
|
||||
for elem in prompt
|
||||
]
|
||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||
|
||||
raise TypeError("prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays")
|
||||
@ -99,26 +72,6 @@ ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
|
||||
ParsedTokensPrompt, ParsedEmbedsPrompt]
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
|
||||
...
|
||||
|
||||
|
||||
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user