Truncation control for embedding models (#14776)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
Gabriel Marinho 2025-04-29 22:24:57 -03:00 committed by GitHub
parent 4055130a85
commit 1c2bc7ead0
21 changed files with 333 additions and 71 deletions
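
This change extends `truncate_prompt_tokens` support for pooling models: `-1` now means "truncate to the model's `max_model_len`", the offline `LLM.encode`/`LLM.embed` API gains the parameter, and the value is threaded through the engine as tokenization kwargs (`truncation=True`, `max_length=k`) rather than being handled only at the serving layer. A minimal online usage sketch against a server started with `--task embed`; the base URL and API key are placeholders, and routing the field through the standard OpenAI client's `extra_body` is an assumption not shown in this diff:

# Sketch only: send an over-long input and let the server truncate it.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="sentence-transformers/all-MiniLM-L12-v2",
    input="Immerse yourself in the enchanting chronicle of calculus. " * 200,
    # -1 truncates to max_model_len; a positive k keeps k tokens.
    extra_body={"truncate_prompt_tokens": -1},
)

print(resp.usage.prompt_tokens)  # capped at the model's max_model_len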

View File

@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import openai
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "embed",
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        str(max_model_len),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = 10
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body={**kwargs})

    assert response["usage"]["prompt_tokens"] == truncation_size


@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = max_model_len + 1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    with pytest.raises(openai.BadRequestError) as err:
        err = await client.post(path="embeddings",
                                cast_to=object,
                                body={**kwargs})

        assert str(err) == f"""openai.BadRequestError:
                    Error code: 400 - {{'object': 'error',
                    'message': 'truncate_prompt_tokens value
                    ({truncation_size})
                    is greater than max_model_len ({max_model_len}).
                    Please, select a smaller truncation size.',
                    'type': 'BadRequestError',
                    'param': None, 'code': 400}}"""


@pytest.mark.asyncio
async def test_max_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = -1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body={**kwargs})

    assert response["usage"]["prompt_tokens"] == max_model_len

View File

@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input_str = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


def test_smaller_truncation_size(vllm_runner,
                                 model_name=MODEL_NAME,
                                 input_str=input_str):
    truncate_prompt_tokens = 10

    with vllm_runner(model_name, task="embed",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

    assert len(prompt_tokens) == truncate_prompt_tokens


def test_max_truncation_size(vllm_runner,
                             model_name=MODEL_NAME,
                             input_str=input_str):
    truncate_prompt_tokens = -1

    with vllm_runner(model_name, task="embed",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

    assert len(prompt_tokens) == max_model_len


def test_bigger_truncation_size(vllm_runner,
                                model_name=MODEL_NAME,
                                input_str=input_str):
    truncate_prompt_tokens = max_model_len + 1

    with pytest.raises(ValueError), vllm_runner(
            model_name, task="embed",
            max_model_len=max_model_len) as vllm_model:
        llm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

        assert llm_output == f"""truncate_prompt_tokens value
                ({truncate_prompt_tokens}) is greater than
                max_model_len ({max_model_len}). Please, select
                a smaller truncation size."""
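
The offline path exercised above mirrors the server behavior: `truncate_prompt_tokens` is now a keyword argument of `LLM.encode` (and `LLM.embed`), and `-1` resolves to the model's `max_model_len`. A minimal sketch outside the test harness, reusing the same model; the prompt text is illustrative:

# Sketch: offline embedding with engine-side truncation.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2",
          task="embed",
          max_model_len=128)

outputs = llm.encode("Immerse yourself in the chronicle of calculus. " * 100,
                     truncate_prompt_tokens=-1)

print(len(outputs[0].prompt_token_ids))  # 128, i.e. truncated to max_model_len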

View File

@ -645,6 +645,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -678,6 +679,7 @@ class LLMEngine:
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -758,6 +760,7 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)

View File

@ -2,7 +2,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional
from typing import AsyncGenerator, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@ -256,7 +256,7 @@ class EngineClient(ABC):
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
model_output: Optional[list[SamplerOutput]] = None,
) -> None:
...

View File

@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
resolve_chat_template_content_format)
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
@ -793,6 +794,7 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -807,6 +809,7 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -821,6 +824,7 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -836,6 +840,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[int],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -851,6 +856,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -864,6 +870,7 @@ class LLM:
prompts: None,
pooling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -882,6 +889,7 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -946,10 +954,15 @@ class LLM:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
)
@ -962,6 +975,7 @@ class LLM:
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -995,6 +1009,7 @@ class LLM:
"Embedding API is only enabled for `--task embed`")
items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
@ -1055,6 +1070,7 @@ class LLM:
encoded_output: list[PoolingRequestOutput] = self.encode(
text_1 + text_2,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
@ -1098,9 +1114,8 @@ class LLM:
pooling_params = PoolingParams()
tokenization_kwargs: dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
parsed_prompts = []
@ -1323,6 +1338,7 @@ class LLM:
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
) -> None:
@ -1359,6 +1375,7 @@ class LLM:
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -1369,6 +1386,7 @@ class LLM:
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -1379,6 +1397,7 @@ class LLM:
prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)

View File

@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-chat-embedding-pooling-params
additional_data: Optional[Any] = None
@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
model: Optional[str] = None
text_1: Union[list[str], str]
text_2: Union[list[str], str]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-score-pooling-params
additional_data: Optional[Any] = None
@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
query: str
documents: list[str]
top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# doc: begin-rerank-pooling-params
additional_data: Optional[Any] = None
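
The protocol change is deliberately small: the Pydantic bound on `truncate_prompt_tokens` is relaxed from `ge=1` to `ge=-1` across the embedding, chat-embedding, score, and rerank requests, so `-1` passes request validation as a sentinel while other negative values are still rejected. A standalone toy model (not vLLM code) showing what that constraint does:

# Toy illustration of the ge=-1 constraint used in the request models above.
from typing import Annotated, Optional

from pydantic import BaseModel, Field, ValidationError


class ToyRequest(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None


print(ToyRequest(truncate_prompt_tokens=-1))  # accepted: sentinel for "use max_model_len"

try:
    ToyRequest(truncate_prompt_tokens=-2)  # still rejected by the ge=-1 bound
except ValidationError as exc:
    print("rejected:", exc)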

View File

@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
@ -85,16 +86,7 @@ class OpenAIServingEmbedding(OpenAIServing):
request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
truncate_prompt_tokens = request.truncate_prompt_tokens
pooling_params = request.to_pooling_params()
@ -104,6 +96,8 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
(
lora_request,
prompt_adapter_request,

View File

@ -173,7 +173,7 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt: str,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
if (self.model_config.encoder_config is not None
@ -271,7 +271,7 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, list[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
@ -292,7 +292,7 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
@ -321,7 +321,7 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> list[TextTokensPrompt]:
"""
@ -356,7 +356,7 @@ class OpenAIServing:
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
request_prompts = await self._tokenize_prompt_input_or_inputs_async(

View File

@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators
@ -85,18 +86,11 @@ class OpenAIServingPooling(OpenAIServing):
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
truncate_prompt_tokens = request.truncate_prompt_tokens
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
(
lora_request,
prompt_adapter_request,

View File

@ -18,6 +18,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -231,11 +232,6 @@ class ServingScores(OpenAIServing):
truncate_prompt_tokens: Optional[int] = None,
) -> list[PoolingRequestOutput]:
tokenization_kwargs: dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
(
lora_request,
prompt_adapter_request,
@ -247,12 +243,9 @@ class ServingScores(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
tokenization_kwargs)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))

View File

@ -46,4 +46,4 @@ def _validate_score_input_lens(
if len(texts_1) == 0:
raise ValueError("At least one text element must be given")
if len(texts_2) == 0:
raise ValueError("At least one text_pair element must be given")
raise ValueError("At least one text_pair element must be given")

View File

@ -3,6 +3,7 @@
import asyncio
import functools
import os
from typing import Any, Optional
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
@ -134,3 +135,26 @@ def cli_env_setup():
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _validate_truncation_size(
    max_model_len: int,
    truncate_prompt_tokens: Optional[int],
    tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[int]:

    if truncate_prompt_tokens is not None:
        if truncate_prompt_tokens <= -1:
            truncate_prompt_tokens = max_model_len

        if truncate_prompt_tokens > max_model_len:
            raise ValueError(
                f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                f"is greater than max_model_len ({max_model_len})."
                f" Please, select a smaller truncation size.")

        if tokenization_kwargs is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

    return truncate_prompt_tokens
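
Since the new helper is shown in full above, its contract can be checked directly: `None` passes through, `-1` resolves to `max_model_len`, anything larger than `max_model_len` raises, and a caller-supplied kwargs dict is populated with the `truncation`/`max_length` pair that is later forwarded to the tokenizer. A small behavior sketch (the import path matches the `vllm.entrypoints.utils` imports added elsewhere in this diff):

# Behavior sketch for _validate_truncation_size as defined above.
from typing import Any

from vllm.entrypoints.utils import _validate_truncation_size

max_model_len = 128

kwargs: dict[str, Any] = {}
assert _validate_truncation_size(max_model_len, -1, kwargs) == 128
assert kwargs == {"truncation": True, "max_length": 128}

assert _validate_truncation_size(max_model_len, 10) == 10
assert _validate_truncation_size(max_model_len, None) is None

try:
    _validate_truncation_size(max_model_len, 129)
except ValueError as exc:
    print(exc)  # "truncate_prompt_tokens value (129) is greater than max_model_len (128). ..."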

View File

@ -2,7 +2,7 @@
import asyncio
from collections.abc import Mapping
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
@ -183,18 +183,21 @@ class InputPreprocessor:
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
add_special_tokens = None
if tokenization_kwargs is None:
tokenization_kwargs = {}
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
tokenization_kwargs["add_special_tokens"] = False
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
@ -203,25 +206,27 @@ class InputPreprocessor:
return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
**tokenization_kwargs)
async def _tokenize_prompt_async(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
add_special_tokens = None
if tokenization_kwargs is None:
tokenization_kwargs = {}
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return await tokenizer.encode_async(
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
tokenization_kwargs["add_special_tokens"] = False
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
def _process_multimodal(
self,
@ -281,6 +286,7 @@ class InputPreprocessor:
def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
@ -304,6 +310,7 @@ class InputPreprocessor:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return token_inputs(
@ -352,6 +359,7 @@ class InputPreprocessor:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return token_inputs(
@ -364,6 +372,7 @@ class InputPreprocessor:
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
@ -375,6 +384,7 @@ class InputPreprocessor:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return token_inputs(
@ -517,6 +527,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
@ -553,7 +564,9 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt):
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"])
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
@ -565,7 +578,10 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(prompt)
inputs = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
@ -581,6 +597,7 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs: SingletonInputs
@ -588,13 +605,18 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"])
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_inputs = await encoder_task
decoder_inputs = None
else:
decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
)
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
@ -606,7 +628,10 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(prompt)
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
@ -638,6 +663,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
@ -660,6 +686,7 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
@ -672,6 +699,7 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
@ -679,6 +707,7 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
@ -691,6 +720,7 @@ class InputPreprocessor:
def preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
@ -711,6 +741,7 @@ class InputPreprocessor:
# Decoder-only operation
return self._process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,
@ -719,6 +750,7 @@ class InputPreprocessor:
async def preprocess_async(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
@ -739,6 +771,7 @@ class InputPreprocessor:
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes,

View File

@ -186,9 +186,10 @@ class SamplingParams(
logits_processors: list of functions that modify logits based on
previously generated tokens, and optionally prompt tokens as
a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
truncate_prompt_tokens: If set to -1, will use the truncation size
supported by the model. If set to an integer k, will use only
the last k tokens from the prompt (i.e., left truncation).
Defaults to None (i.e., no truncation).
guided_decoding: If provided, the engine will construct a guided
decoding logits processor from these parameters. Defaults to None.
logit_bias: If provided, the engine will construct a logits processor

View File

@ -55,6 +55,8 @@ def encode_tokens(
tokenizer: AnyTokenizer,
text: str,
*,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
"""
@ -64,10 +66,18 @@ def encode_tokens(
:code:`add_special_tokens=None` means to use the backend's default
settings.
"""
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
kw_args: dict[str, Any] = {}
if max_length is not None:
kw_args["max_length"] = max_length
if truncation is not None:
kw_args["truncation"] = truncation
if add_special_tokens is not None:
kw_args["add_special_tokens"] = add_special_tokens
return tokenizer.encode(text, **kw_args)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
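
`encode_tokens` now simply forwards `truncation` and `max_length` to the underlying Hugging Face tokenizer, which is what ultimately enforces the limit. A rough standalone sketch of that final call, using the same model as the tests (the behavior shown is standard `transformers`, not vLLM-specific):

# Sketch of what the forwarded kwargs do at the Hugging Face tokenizer level.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-MiniLM-L12-v2")

text = "Immerse yourself in the enchanting chronicle of calculus. " * 50

token_ids = tokenizer.encode(text, truncation=True, max_length=10)
print(len(token_ids))  # 10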

View File

@ -94,6 +94,8 @@ class TokenizerBase(ABC):
@abstractmethod
def encode(self,
text: str,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
raise NotImplementedError()

View File

@ -45,11 +45,16 @@ class TokenizerGroup:
def encode(self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
@ -57,11 +62,15 @@ class TokenizerGroup:
async def encode_async(
self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

View File

@ -359,6 +359,8 @@ class MistralTokenizer(TokenizerBase):
def encode(self,
text: str,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.

View File

@ -2,7 +2,7 @@
import asyncio
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
@ -201,6 +201,7 @@ class AsyncLLM(EngineClient):
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -219,7 +220,8 @@ class AsyncLLM(EngineClient):
# Convert Input --> Request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
if params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)

View File

@ -175,6 +175,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -182,7 +183,8 @@ class LLMEngine:
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1

View File

@ -2,7 +2,7 @@
import time
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
@ -198,6 +198,7 @@ class Processor:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@ -224,6 +225,7 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash,