diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py
index ed003939c44b..f84605690c53 100644
--- a/tests/entrypoints/openai/test_token_in_token_out.py
+++ b/tests/entrypoints/openai/test_token_in_token_out.py
@@ -54,7 +54,7 @@ async def test_token_in_token_out_and_logprobs(server):
         prompt=token_ids,
         max_tokens=20,
         temperature=0,
-        echo=True,
+        echo=False,
         extra_body={
             "return_token_ids": True,
         },
diff --git a/tests/test_inputs.py b/tests/test_inputs.py
index b61b95bc4333..10a18e2d871f 100644
--- a/tests/test_inputs.py
+++ b/tests/test_inputs.py
@@ -4,7 +4,7 @@
 import pytest
 
 from vllm.inputs import zip_enc_dec_prompts
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import parse_raw_prompts
 
 pytestmark = pytest.mark.cpu_test
@@ -31,30 +31,30 @@ INPUTS_SLICES = [
 ]
 
 
-def test_parse_single_batch_empty():
+def test_parse_raw_single_batch_empty():
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([])
+        parse_raw_prompts([])
     with pytest.raises(ValueError, match="at least one prompt"):
-        parse_and_batch_prompt([[]])
+        parse_raw_prompts([[]])
 
 
 @pytest.mark.parametrize('string_input', STRING_INPUTS)
-def test_parse_single_batch_string_consistent(string_input: str):
-    assert parse_and_batch_prompt(string_input) \
-        == parse_and_batch_prompt([string_input])
+def test_parse_raw_single_batch_string_consistent(string_input: str):
+    assert parse_raw_prompts(string_input) \
+        == parse_raw_prompts([string_input])
 
 
 @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
-def test_parse_single_batch_token_consistent(token_input: list[int]):
-    assert parse_and_batch_prompt(token_input) \
-        == parse_and_batch_prompt([token_input])
+def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
+    assert parse_raw_prompts(token_input) \
+        == parse_raw_prompts([token_input])
 
 
 @pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
-def test_parse_single_batch_string_slice(inputs_slice: slice):
-    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
-        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
+    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] \
+        == parse_raw_prompts(STRING_INPUTS[inputs_slice])
 
 
 # yapf: disable
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index d0756e42b796..6e4113e6cf1e 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -691,6 +691,5 @@ class OpenAIServingCompletion(OpenAIServing):
             truncate_prompt_tokens=request.truncate_prompt_tokens,
             add_special_tokens=request.add_special_tokens,
             cache_salt=request.cache_salt,
-            needs_detokenization=bool(request.echo
-                                      and not request.return_token_ids),
+            needs_detokenization=bool(request.echo),
         )
diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py
index d7ce57c728ba..f6fc045a1877 100644
--- a/vllm/entrypoints/renderer.py
+++ b/vllm/entrypoints/renderer.py
@@ -13,8 +13,9 @@
 from pydantic import Field
 
 from vllm.config import ModelConfig
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer
@@ -41,6 +42,27 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""
 
+    def verify_truncate_prompt_tokens(
+            self, model_config: ModelConfig) -> Optional[int]:
+        """Validate and normalize `truncate_prompt_tokens` parameter."""
+        truncate_prompt_tokens = self.truncate_prompt_tokens
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = model_config.max_model_len
+
+        max_length = self.max_length
+        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
+            raise ValueError(
+                f"{truncate_prompt_tokens=} cannot be greater than "
+                f"{max_length=}. Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+
 
 class BaseRenderer(ABC):
     """
@@ -74,7 +96,7 @@ class BaseRenderer(ABC):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -107,7 +129,7 @@ class BaseRenderer(ABC):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
@@ -189,47 +211,25 @@ class CompletionRenderer(BaseRenderer):
         self,
         *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        config: "RenderConfig",
+        config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance.
         See base class for detailed parameter documentation.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
 
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
+        tasks = (self._create_prompt(
+            prompt_input,
+            config=config,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
 
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is True:
-                # Token input
-                # Note: detokenization is needed when echo is enabled,
-                # where the input token IDs are decoded back to text.
-                task = self._maybe_detokenize(prompt_input["content"],
-                                              config.max_length,
-                                              truncate_prompt_tokens,
-                                              config.cache_salt,
-                                              config.needs_detokenization)
-            else:
-                # Text input
-                task = self._tokenize(prompt_input["content"],
-                                      config.max_length,
-                                      truncate_prompt_tokens,
-                                      config.add_special_tokens,
-                                      config.cache_salt)
-            tasks.append(task)
-
-        # Wait for all text tokenization to finish
-        if tasks:
-            tokenized_text_prompts = await asyncio.gather(*tasks)
-            return tokenized_text_prompts
-
-        return []
+        return await asyncio.gather(*tasks)
 
     async def render_prompt_and_embeds(
         self,
@@ -237,14 +237,14 @@ class CompletionRenderer(BaseRenderer):
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        config: "RenderConfig",
+        config: RenderConfig,
    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts.
         At least one of `prompt_or_prompts` or `prompt_embeds`
         must be provided.
         """
-        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            config.truncate_prompt_tokens, config.max_length)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
+            self.model_config)
         if truncate_prompt_tokens == 0:
             return []
@@ -265,29 +265,6 @@ class CompletionRenderer(BaseRenderer):
 
         return rendered
 
-    def _validate_and_normalize_truncate_tokens(
-        self,
-        truncate_prompt_tokens: Optional[int],
-        max_length: Optional[int],
-    ) -> Optional[int]:
-        """Validate and normalize truncate_prompt_tokens parameter."""
-        if truncate_prompt_tokens is None:
-            return None
-
-        if truncate_prompt_tokens == 0:
-            return 0
-
-        if truncate_prompt_tokens < 0:
-            truncate_prompt_tokens = self.model_config.max_model_len
-
-        if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
-            raise ValueError(
-                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                f"cannot be greater than max_length ({max_length}). "
-                f"Please select a smaller truncation size.")
-
-        return truncate_prompt_tokens
-
     def _maybe_apply_truncation(
             self, token_ids: list[int],
             truncate_prompt_tokens: Optional[int]) -> list[int]:
@@ -299,7 +276,38 @@ class CompletionRenderer(BaseRenderer):
 
         return token_ids[-truncate_prompt_tokens:]
 
-    async def _tokenize(
+    async def _create_prompt(
+        self,
+        prompt_input: Union[EngineTextPrompt, EngineTokensPrompt],
+        config: RenderConfig,
+        truncate_prompt_tokens: Optional[int],
+    ) -> EngineTokensPrompt:
+        prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
+
+        if prompt_token_ids is not None:
+            # NOTE: detokenization is needed when echo is enabled,
+            # where the input token IDs are decoded back to text.
+            return await self._create_prompt_from_token_ids(
+                prompt_token_ids,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.cache_salt,
+                config.needs_detokenization,
+            )
+
+        if prompt is not None:
+            return await self._create_prompt_from_text(
+                prompt,
+                config.max_length,
+                truncate_prompt_tokens,
+                config.add_special_tokens,
+                config.cache_salt,
+            )
+
+        # TODO: Also handle embeds prompt using this method
+        raise NotImplementedError
+
+    async def _create_prompt_from_text(
         self,
         text: str,
         max_length: Optional[int],
@@ -330,7 +338,7 @@ class CompletionRenderer(BaseRenderer):
         return self._create_tokens_prompt(encoded.input_ids, max_length,
                                           cache_salt, text)
 
-    async def _maybe_detokenize(
+    async def _create_prompt_from_token_ids(
         self,
         token_ids: list[int],
         max_length: Optional[int],
@@ -343,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
                                                  truncate_prompt_tokens)
 
         prompt = None
-        if needs_detokenization is True:
+        if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index 123c81173120..f93817bd463d 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
 from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
-                    Union, cast, overload)
+                    Union, cast)
 
 from typing_extensions import TypeIs
@@ -16,34 +16,12 @@ if TYPE_CHECKING:
     import torch
 
 
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-
-
-class ParsedTokens(TypedDict):
-    content: list[int]
-    is_tokens: Literal[True]
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
-    ...
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
-    ...
-
-
-def parse_and_batch_prompt(
+def parse_raw_prompts(
     prompt: Union[str, list[str], list[int], list[list[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
     if isinstance(prompt, str):
         # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
+        return [TextPrompt(prompt=prompt)]
 
     if isinstance(prompt, list):
         if len(prompt) == 0:
@@ -52,13 +30,11 @@ def parse_and_batch_prompt(
         if is_list_of(prompt, str):
             # case 2: array of strings
             prompt = cast(list[str], prompt)
-            return [
-                ParsedText(content=elem, is_tokens=False) for elem in prompt
-            ]
+            return [TextPrompt(prompt=elem) for elem in prompt]
 
         if is_list_of(prompt, int):
             # case 3: array of tokens
             prompt = cast(list[int], prompt)
-            return [ParsedTokens(content=prompt, is_tokens=True)]
+            return [TokensPrompt(prompt_token_ids=prompt)]
 
         if is_list_of(prompt, list):
             prompt = cast(list[list[int]], prompt)
             if len(prompt[0]) == 0:
@@ -66,10 +42,7 @@ def parse_and_batch_prompt(
 
             if is_list_of(prompt[0], int):
                 # case 4: array of token arrays
-                return [
-                    ParsedTokens(content=elem, is_tokens=True)
-                    for elem in prompt
-                ]
+                return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
 
     raise TypeError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
@@ -99,26 +72,6 @@ ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                               ParsedTokensPrompt, ParsedEmbedsPrompt]
 
 
-@overload
-def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
-    ...
-
-
-@overload
-def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
-    ...
-
-
 def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)
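For reviewers, a rough usage sketch of the renamed helper (not part of the diff). It assumes TextPrompt and TokensPrompt stay plain TypedDicts from vllm.inputs.data, so the returned items compare equal to ordinary dicts:

# Hypothetical sketch; mirrors the four input cases handled by
# parse_raw_prompts in vllm/inputs/parse.py.
from vllm.inputs.parse import parse_raw_prompts

# case 1: a single string becomes one TextPrompt
assert parse_raw_prompts("Hello") == [{"prompt": "Hello"}]
# case 2: a list of strings becomes one TextPrompt per element
assert parse_raw_prompts(["Hello", "world"]) == [{"prompt": "Hello"},
                                                 {"prompt": "world"}]
# case 3: a flat list of ints is a single token prompt
assert parse_raw_prompts([1, 2, 3]) == [{"prompt_token_ids": [1, 2, 3]}]
# case 4: a list of token lists becomes one TokensPrompt per element
assert parse_raw_prompts([[1, 2], [3]]) == [{"prompt_token_ids": [1, 2]},
                                            {"prompt_token_ids": [3]}]

On the renderer side, _create_prompt then dispatches on these via get_prompt_components, so token inputs are only detokenized when needs_detokenization is set (i.e. when echo is requested).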