[Renderer] Clean up renderer code (#26216)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 119f00630b (parent: a42d2df75f)
@@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
         prompt=token_ids,
         max_tokens=20,
         temperature=0,
-        echo=True,
+        echo=False,
         extra_body={
             "return_token_ids": True,
         },
@@ -4,7 +4,7 @@
 import pytest
 
 from vllm.inputs import zip_enc_dec_prompts
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import parse_raw_prompts
 
 pytestmark = pytest.mark.cpu_test
 
@@ -31,30 +31,30 @@ INPUTS_SLICES = [
 ]
 
 
-def test_parse_single_batch_empty():
+def test_parse_raw_single_batch_empty():
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([])
+        parse_raw_prompts([])
 
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([[]])
+        parse_raw_prompts([[]])
 
 
 @pytest.mark.parametrize('string_input', STRING_INPUTS)
-def test_parse_single_batch_string_consistent(string_input: str):
-    assert parse_and_batch_prompt(string_input) \
-        == parse_and_batch_prompt([string_input])
+def test_parse_raw_single_batch_string_consistent(string_input: str):
+    assert parse_raw_prompts(string_input) \
+        == parse_raw_prompts([string_input])
 
 
 @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
-def test_parse_single_batch_token_consistent(token_input: list[int]):
-    assert parse_and_batch_prompt(token_input) \
-        == parse_and_batch_prompt([token_input])
+def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
-    assert parse_raw_prompts(token_input) \
+    assert parse_raw_prompts(token_input) \
+        == parse_raw_prompts([token_input])
 
 
 @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
-def test_parse_single_batch_string_slice(inputs_slice: slice):
-    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
-        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
+    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
+        == parse_raw_prompts(STRING_INPUTS[inputs_slice])
 
 
 # yapf: disable
@@ -691,6 +691,5 @@ class OpenAIServingCompletion(OpenAIServing):
             truncate_prompt_tokens=request.truncate_prompt_tokens,
             add_special_tokens=request.add_special_tokens,
             cache_salt=request.cache_salt,
-            needs_detokenization=bool(request.echo
-                                      and not request.return_token_ids),
+            needs_detokenization=bool(request.echo),
         )
@@ -13,8 +13,9 @@ from pydantic import Field
 
 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer
 
@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""
 
+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+
 
 class BaseRenderer(ABC):
     """
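Note (not part of the diff): a minimal usage sketch of the new RenderConfig.verify_truncate_prompt_tokens helper added above. The import path and constructor defaults are assumed, and SimpleNamespace stands in for a real vllm.config.ModelConfig, since only max_model_len is consulted here.

    from types import SimpleNamespace

    from vllm.entrypoints.renderer import RenderConfig  # import path assumed

    # Stand-in for ModelConfig; only max_model_len is read by the helper.
    model_config = SimpleNamespace(max_model_len=4096)

    # A negative value is normalized to max_model_len and then checked against
    # max_length, so this raises ValueError here (4096 > 512).
    cfg = RenderConfig(max_length=512, truncate_prompt_tokens=-1)
    try:
        cfg.verify_truncate_prompt_tokens(model_config)
    except ValueError as exc:
        print(exc)

    # None passes through unchanged; 0 short-circuits to 0.
    print(RenderConfig(truncate_prompt_tokens=None).verify_truncate_prompt_tokens(model_config))
    print(RenderConfig(truncate_prompt_tokens=0).verify_truncate_prompt_tokens(model_config))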
@@ -74,7 +96,7 @@ class BaseRenderer(ABC):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@ class BaseRenderer(ABC):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,47 +211,25 @@ class CompletionRenderer(BaseRenderer):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
-
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-
-        return []
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+
+        return await asyncio.gather(*tasks)
 
     async def render_prompt_and_embeds(
         self,
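Note (not part of the diff): the rewritten render_prompt above replaces the manual task list with a generator of per-prompt coroutines handed to a single asyncio.gather call. A self-contained sketch of that pattern with stand-in names (nothing below is vLLM API):

    import asyncio


    async def create_prompt(raw_prompt: dict) -> dict:
        # Stand-in for CompletionRenderer._create_prompt: turn one raw prompt
        # into an engine-ready tokens prompt, possibly via an async tokenizer.
        await asyncio.sleep(0)
        return {"prompt_token_ids": raw_prompt.get("prompt_token_ids", [101, 102])}


    async def render(raw_prompts: list[dict]) -> list[dict]:
        # One coroutine per parsed prompt, awaited together, preserving order.
        tasks = (create_prompt(p) for p in raw_prompts)
        return await asyncio.gather(*tasks)


    print(asyncio.run(render([{"prompt": "Hello"}, {"prompt_token_ids": [1, 2]}])))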
@@ -237,14 +237,14 @@ class CompletionRenderer(BaseRenderer):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
@@ -265,29 +265,6 @@ class CompletionRenderer(BaseRenderer):
 
         return rendered
 
-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-
-        if truncate_prompt_tokens == 0:
-            return 0
-
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-
-        return truncate_prompt_tokens
-
     def _maybe_apply_truncation(
             self, token_ids: list[int],
             truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@ class CompletionRenderer(BaseRenderer):
 
         return token_ids[-truncate_prompt_tokens:]
 
-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
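Note (not part of the diff): the new _create_prompt dispatches on whichever component get_prompt_components extracts from the raw prompt. Only the three-way unpacking shown above is taken from the change; the helpers below are stand-ins written for illustration, not vLLM API.

    def get_components(prompt_input: dict):
        # Stand-in mirroring `prompt, prompt_token_ids, _ = get_prompt_components(...)`.
        return (prompt_input.get("prompt"),
                prompt_input.get("prompt_token_ids"),
                prompt_input.get("prompt_embeds"))


    def create_prompt(prompt_input: dict) -> dict:
        prompt, prompt_token_ids, _ = get_components(prompt_input)
        if prompt_token_ids is not None:
            # Token input: may still be detokenized later when echo is enabled.
            return {"prompt_token_ids": prompt_token_ids}
        if prompt is not None:
            # Text input: would normally go through the async tokenizer.
            return {"prompt_token_ids": [ord(c) for c in prompt]}
        raise NotImplementedError("embeds prompts are not handled here")


    print(create_prompt({"prompt": "hi"}))
    print(create_prompt({"prompt_token_ids": [1, 2, 3]}))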
@@ -330,7 +338,7 @@ class CompletionRenderer(BaseRenderer):
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)
 
-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
                                                truncate_prompt_tokens)
 
         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)
 
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
 from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
-                    Union, cast, overload)
+                    Union, cast)
 
 from typing_extensions import TypeIs
 
@@ -16,34 +16,12 @@ if TYPE_CHECKING:
     import torch
 
 
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-
-
-class ParsedTokens(TypedDict):
-    content: list[int]
-    is_tokens: Literal[True]
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
-    ...
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
-    ...
-
-
-def parse_and_batch_prompt(
+def parse_raw_prompts(
     prompt: Union[str, list[str], list[int], list[list[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
     if isinstance(prompt, str):
         # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
+        return [TextPrompt(prompt=prompt)]
 
     if isinstance(prompt, list):
         if len(prompt) == 0:
@@ -52,13 +30,11 @@ def parse_and_batch_prompt(
     if is_list_of(prompt, str):
         # case 2: array of strings
         prompt = cast(list[str], prompt)
-        return [
-            ParsedText(content=elem, is_tokens=False) for elem in prompt
-        ]
+        return [TextPrompt(prompt=elem) for elem in prompt]
     if is_list_of(prompt, int):
         # case 3: array of tokens
         prompt = cast(list[int], prompt)
-        return [ParsedTokens(content=prompt, is_tokens=True)]
+        return [TokensPrompt(prompt_token_ids=prompt)]
     if is_list_of(prompt, list):
         prompt = cast(list[list[int]], prompt)
         if len(prompt[0]) == 0:
@@ -66,10 +42,7 @@ def parse_and_batch_prompt(
 
         if is_list_of(prompt[0], int):
             # case 4: array of token arrays
-            return [
-                ParsedTokens(content=elem, is_tokens=True)
-                for elem in prompt
-            ]
+            return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
 
     raise TypeError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
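Note (not part of the diff): because TextPrompt and TokensPrompt are TypedDicts, parse_raw_prompts now returns plain prompt dicts directly usable by the engine. A small usage sketch, assuming the import path used in the updated tests; expected values are shown in comments.

    from vllm.inputs.parse import parse_raw_prompts

    parse_raw_prompts("Hello")        # [{'prompt': 'Hello'}]
    parse_raw_prompts([1, 2, 3])      # [{'prompt_token_ids': [1, 2, 3]}]

    # Batched forms map element-wise.
    parse_raw_prompts(["a", "b"])     # [{'prompt': 'a'}, {'prompt': 'b'}]
    parse_raw_prompts([[1, 2], [3]])  # [{'prompt_token_ids': [1, 2]},
                                      #  {'prompt_token_ids': [3]}]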
@@ -99,26 +72,6 @@ ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                               ParsedTokensPrompt, ParsedEmbedsPrompt]
 
 
-@overload
-def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
-    ...
-
-
 def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)