Truncation control for embedding models (#14776)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
parent 4055130a85
commit 1c2bc7ead0
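
For context, this change lets embedding-style requests pass `truncate_prompt_tokens`, where -1 means "truncate to the model's max_model_len". A minimal client-side sketch follows (not part of the commit; the server command, port, and base URL are assumptions, while the model name and request shape mirror the tests below):

# Sketch only: assumes a server such as
#   vllm serve sentence-transformers/all-MiniLM-L12-v2 --task embed --max-model-len 128
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Keep only the last 10 prompt tokens (left truncation).
response = client.post(path="embeddings",
                       cast_to=object,
                       body={
                           "model": "sentence-transformers/all-MiniLM-L12-v2",
                           "input": "A passage much longer than ten tokens ...",
                           "truncate_prompt_tokens": 10,
                       })
print(response["usage"]["prompt_tokens"])  # expected: 10
# Passing -1 instead truncates to max_model_len; values above it return HTTP 400.
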
tests/entrypoints/openai/test_truncation.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import openai
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "embed",
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        str(max_model_len),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = 10
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body={**kwargs})

    assert response["usage"]["prompt_tokens"] == truncation_size


@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = max_model_len + 1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    with pytest.raises(openai.BadRequestError) as err:
        err = await client.post(path="embeddings",
                                cast_to=object,
                                body={**kwargs})

        assert str(err) == f"""openai.BadRequestError:
        Error code: 400 - {{'object': 'error',
        'message': 'truncate_prompt_tokens value
        ({truncation_size})
        is greater than max_model_len ({max_model_len}).
        Please, select a smaller truncation size.',
        'type': 'BadRequestError',
        'param': None, 'code': 400}}"""


@pytest.mark.asyncio
async def test_max_truncation_size(client: openai.AsyncOpenAI):
    truncation_size = -1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body={**kwargs})

    assert response["usage"]["prompt_tokens"] == max_model_len
tests/models/embedding/language/test_truncation_control.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input_str = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


def test_smaller_truncation_size(vllm_runner,
                                 model_name=MODEL_NAME,
                                 input_str=input_str):

    truncate_prompt_tokens = 10

    with vllm_runner(model_name, task="embed",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

    assert len(prompt_tokens) == truncate_prompt_tokens


def test_max_truncation_size(vllm_runner,
                             model_name=MODEL_NAME,
                             input_str=input_str):
    truncate_prompt_tokens = -1

    with vllm_runner(model_name, task="embed",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

    assert len(prompt_tokens) == max_model_len


def test_bigger_truncation_size(vllm_runner,
                                model_name=MODEL_NAME,
                                input_str=input_str):

    truncate_prompt_tokens = max_model_len + 1

    with pytest.raises(ValueError), vllm_runner(
            model_name, task="embed",
            max_model_len=max_model_len) as vllm_model:

        llm_output = vllm_model.model.encode(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

        assert llm_output == f"""truncate_prompt_tokens value
        ({truncate_prompt_tokens}) is greater than
        max_model_len ({max_model_len}). Please, select
        a smaller truncation size."""
@@ -645,6 +645,7 @@ class LLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -678,6 +679,7 @@ class LLMEngine:
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -758,6 +760,7 @@ class LLMEngine:

         processed_inputs = self.input_preprocessor.preprocess(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
@@ -2,7 +2,7 @@

 import asyncio
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Mapping, Optional
+from typing import AsyncGenerator, Mapping, Optional

 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@@ -256,7 +256,7 @@ class EngineClient(ABC):
     async def do_log_stats(
         self,
         scheduler_outputs: Optional[SchedulerOutputs] = None,
-        model_output: Optional[List[SamplerOutput]] = None,
+        model_output: Optional[list[SamplerOutput]] = None,
     ) -> None:
         ...
@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          resolve_chat_template_content_format)
 from vllm.entrypoints.score_utils import (_cosine_similarity,
                                           _validate_score_input_lens)
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -793,6 +794,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -807,6 +809,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -821,6 +824,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[list[int]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -836,6 +840,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[int],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -851,6 +856,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[list[int]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -864,6 +870,7 @@ class LLM:
         prompts: None,
         pooling_params: None,
         prompt_token_ids: Union[list[int], list[list[int]]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -882,6 +889,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -946,10 +954,15 @@ class LLM:
         for pooling_param in pooling_params:
             pooling_param.verify(self.llm_engine.model_config)

+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)
+
         self._validate_and_add_requests(
             prompts=parsed_prompts,
             params=pooling_params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
         )

@@ -962,6 +975,7 @@ class LLM:
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -995,6 +1009,7 @@ class LLM:
                 "Embedding API is only enabled for `--task embed`")

         items = self.encode(prompts,
+                            truncate_prompt_tokens=truncate_prompt_tokens,
                             use_tqdm=use_tqdm,
                             pooling_params=pooling_params,
                             lora_request=lora_request,
@@ -1055,6 +1070,7 @@ class LLM:

         encoded_output: list[PoolingRequestOutput] = self.encode(
             text_1 + text_2,
+            truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
@@ -1098,9 +1114,8 @@ class LLM:
             pooling_params = PoolingParams()

         tokenization_kwargs: dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)

         parsed_prompts = []

@@ -1323,6 +1338,7 @@ class LLM:
                                        Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
@@ -1359,6 +1375,7 @@ class LLM:
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
+                tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 prompt_adapter_request=prompt_adapter_request,
@@ -1369,6 +1386,7 @@ class LLM:
         self,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -1379,6 +1397,7 @@ class LLM:
             prompt,
             params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
         )
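
The `LLM` entry points above gain the same knob for offline use. A hedged sketch follows (not part of the commit; the constructor arguments and the output field are assumptions based on the tests in this PR):

# Sketch only: offline embedding with prompt truncation.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2",
          task="embed",
          max_model_len=128)

# Keep only the last 10 prompt tokens; -1 would truncate to max_model_len,
# and anything larger than max_model_len raises ValueError.
outputs = llm.embed("A passage much longer than ten tokens ...",
                    truncate_prompt_tokens=10)
print(len(outputs[0].prompt_token_ids))  # expected: 10
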
@@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-chat-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
     model: Optional[str] = None
     text_1: Union[list[str], str]
     text_2: Union[list[str], str]
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-score-pooling-params
     additional_data: Optional[Any] = None
@@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
     query: str
     documents: list[str]
     top_n: int = Field(default_factory=lambda: 0)
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-rerank-pooling-params
     additional_data: Optional[Any] = None
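
Relaxing the Pydantic bound from ge=1 to ge=-1 makes -1 a valid sentinel on the embeddings, pooling, score, and rerank request schemas. A hedged HTTP sketch (not part of the commit; the endpoint path and port are assumptions):

# Sketch only: score request that truncates both texts to max_model_len.
import requests

payload = {
    "model": "sentence-transformers/all-MiniLM-L12-v2",
    "text_1": "What is calculus?",
    "text_2": ["A very long passage about the history of calculus ..."],
    "truncate_prompt_tokens": -1,
}
print(requests.post("http://localhost:8000/score", json=payload).json())
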
@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput)
@@ -85,16 +86,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

-        truncate_prompt_tokens = None
-
-        if request.truncate_prompt_tokens is not None:
-            if request.truncate_prompt_tokens <= self.max_model_len:
-                truncate_prompt_tokens = request.truncate_prompt_tokens
-            else:
-                return self.create_error_response(
-                    "truncate_prompt_tokens value is "
-                    "greater than max_model_len."
-                    " Please, select a smaller truncation size.")
+        truncate_prompt_tokens = request.truncate_prompt_tokens

         pooling_params = request.to_pooling_params()

@@ -104,6 +96,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))

         try:
+            truncate_prompt_tokens = _validate_truncation_size(
+                self.max_model_len, truncate_prompt_tokens)
             (
                 lora_request,
                 prompt_adapter_request,
@@ -173,7 +173,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt: str,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         if (self.model_config.encoder_config is not None
@@ -271,7 +271,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_input: Union[str, list[int]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> TextTokensPrompt:
         """
@@ -292,7 +292,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
         """
@@ -321,7 +321,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> list[TextTokensPrompt]:
         """
@@ -356,7 +356,7 @@ class OpenAIServing:
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
         request_prompts = await self._tokenize_prompt_input_or_inputs_async(
@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               PoolingResponseData, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
@@ -85,18 +86,11 @@ class OpenAIServingPooling(OpenAIServing):
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

-        truncate_prompt_tokens = None
-
-        if request.truncate_prompt_tokens is not None:
-            if request.truncate_prompt_tokens <= self.max_model_len:
-                truncate_prompt_tokens = request.truncate_prompt_tokens
-            else:
-                return self.create_error_response(
-                    "truncate_prompt_tokens value is "
-                    "greater than max_model_len."
-                    " Please, select a smaller truncation size.")
+        truncate_prompt_tokens = request.truncate_prompt_tokens

         try:
+            truncate_prompt_tokens = _validate_truncation_size(
+                self.max_model_len, truncate_prompt_tokens)
             (
                 lora_request,
                 prompt_adapter_request,
@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.score_utils import (_cosine_similarity,
                                           _validate_score_input_lens)
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -231,11 +232,6 @@ class ServingScores(OpenAIServing):
         truncate_prompt_tokens: Optional[int] = None,
     ) -> list[PoolingRequestOutput]:

-        tokenization_kwargs: dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
         (
             lora_request,
             prompt_adapter_request,
@@ -247,12 +243,9 @@ class ServingScores(OpenAIServing):

         tokenizer = await self.engine_client.get_tokenizer(lora_request)

-        if truncate_prompt_tokens is not None and \
-                truncate_prompt_tokens > self.max_model_len:
-            raise ValueError(
-                f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
-                f"is greater than max_model_len ({self.max_model_len})."
-                f" Please, select a smaller truncation size.")
+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
+                                  tokenization_kwargs)

         trace_headers = (None if raw_request is None else await
                          self._get_trace_headers(raw_request.headers))
@@ -46,4 +46,4 @@ def _validate_score_input_lens(
     if len(texts_1) == 0:
         raise ValueError("At least one text element must be given")
     if len(texts_2) == 0:
-        raise ValueError("At least one text_pair element must be given")
+        raise ValueError("At least one text_pair element must be given")
@@ -3,6 +3,7 @@
 import asyncio
 import functools
 import os
+from typing import Any, Optional

 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -134,3 +135,26 @@ def cli_env_setup():
     if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
         logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
+def _validate_truncation_size(
+    max_model_len: int,
+    truncate_prompt_tokens: Optional[int],
+    tokenization_kwargs: Optional[dict[str, Any]] = None,
+) -> Optional[int]:
+
+    if truncate_prompt_tokens is not None:
+        if truncate_prompt_tokens <= -1:
+            truncate_prompt_tokens = max_model_len
+
+        if truncate_prompt_tokens > max_model_len:
+            raise ValueError(
+                f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+                f"is greater than max_model_len ({max_model_len})."
+                f" Please, select a smaller truncation size.")
+
+        if tokenization_kwargs is not None:
+            tokenization_kwargs["truncation"] = True
+            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+    return truncate_prompt_tokens
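
The helper added above centralizes the truncation rules: None passes through, -1 is rewritten to max_model_len, anything larger than max_model_len raises, and a supplied tokenization_kwargs dict is filled in place. A small usage sketch (not part of the commit; it calls the internal helper directly, so it is illustrative only):

from typing import Any

from vllm.entrypoints.utils import _validate_truncation_size

max_model_len = 128

def demo(value):
    kwargs: dict[str, Any] = {}
    return _validate_truncation_size(max_model_len, value, kwargs), kwargs

print(demo(None))  # (None, {})  -> no truncation requested
print(demo(-1))    # (128, {'truncation': True, 'max_length': 128})
print(demo(10))    # (10, {'truncation': True, 'max_length': 10})
# demo(129) raises ValueError: "truncate_prompt_tokens value (129) is greater ..."
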
@@ -2,7 +2,7 @@

 import asyncio
 from collections.abc import Mapping
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from typing_extensions import assert_never

@@ -183,18 +183,21 @@ class InputPreprocessor:
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """
         Apply the model's tokenizer to a text prompt, returning the
         corresponding token IDs.
         """
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
+
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
+            tokenization_kwargs["add_special_tokens"] = False

         if (self.model_config.encoder_config is not None
                 and self.model_config.encoder_config.get(
@@ -203,25 +206,27 @@ class InputPreprocessor:

         return tokenizer.encode(prompt=prompt,
                                 lora_request=lora_request,
-                                add_special_tokens=add_special_tokens)
+                                **tokenization_kwargs)

     async def _tokenize_prompt_async(
         self,
         prompt: str,
         lora_request: Optional[LoRARequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
-        add_special_tokens = None
+        if tokenization_kwargs is None:
+            tokenization_kwargs = {}
+
         if self.model_config.hf_config.model_type == "whisper":
             # For Whisper, special tokens should be provided by the user based
             # on the task and language of their request. Also needed to avoid
             # appending an EOS token to the prompt which disrupts generation.
-            add_special_tokens = False
-        return await tokenizer.encode_async(
-            prompt=prompt,
-            lora_request=lora_request,
-            add_special_tokens=add_special_tokens)
+            tokenization_kwargs["add_special_tokens"] = False
+        return await tokenizer.encode_async(prompt=prompt,
+                                            lora_request=lora_request,
+                                            **tokenization_kwargs)

     def _process_multimodal(
         self,
@@ -281,6 +286,7 @@ class InputPreprocessor:
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -304,6 +310,7 @@ class InputPreprocessor:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return token_inputs(
@@ -352,6 +359,7 @@ class InputPreprocessor:
         prompt_token_ids = self._tokenize_prompt(
             prompt_text,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
         )

         return token_inputs(
@@ -364,6 +372,7 @@ class InputPreprocessor:
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -375,6 +384,7 @@ class InputPreprocessor:
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
                 lora_request=lora_request,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return token_inputs(
@@ -517,6 +527,7 @@ class InputPreprocessor:
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -553,7 +564,9 @@ class InputPreprocessor:

         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
@@ -565,7 +578,10 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = self._prompt_to_llm_inputs(prompt)
+            inputs = self._prompt_to_llm_inputs(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -581,6 +597,7 @@ class InputPreprocessor:
     async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_inputs: SingletonInputs
@@ -588,13 +605,18 @@ class InputPreprocessor:

         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"])
+                prompt["encoder_prompt"],
+                tokenization_kwargs=tokenization_kwargs,
+            )

             if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_inputs = await encoder_task
                 decoder_inputs = None
             else:
-                decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
+                decoder_task = self._prompt_to_llm_inputs_async(
+                    decoder_input,
+                    tokenization_kwargs=tokenization_kwargs,
+                )

                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
@@ -606,7 +628,10 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = await self._prompt_to_llm_inputs_async(prompt)
+            inputs = await self._prompt_to_llm_inputs_async(
+                prompt,
+                tokenization_kwargs=tokenization_kwargs,
+            )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
@@ -638,6 +663,7 @@ class InputPreprocessor:
     def _process_decoder_only_prompt(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -660,6 +686,7 @@ class InputPreprocessor:

         prompt_comps = self._prompt_to_llm_inputs(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -672,6 +699,7 @@ class InputPreprocessor:
     async def _process_decoder_only_prompt_async(
         self,
         prompt: SingletonPrompt,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -679,6 +707,7 @@ class InputPreprocessor:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._prompt_to_llm_inputs_async(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -691,6 +720,7 @@ class InputPreprocessor:
     def preprocess(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -711,6 +741,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return self._process_decoder_only_prompt(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -719,6 +750,7 @@ class InputPreprocessor:
     async def preprocess_async(
         self,
         prompt: PromptType,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -739,6 +771,7 @@ class InputPreprocessor:
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -186,9 +186,10 @@ class SamplingParams(
         logits_processors: list of functions that modify logits based on
             previously generated tokens, and optionally prompt tokens as
             a first argument.
-        truncate_prompt_tokens: If set to an integer k, will use only the last k
-            tokens from the prompt (i.e., left truncation). Defaults to None
-            (i.e., no truncation).
+        truncate_prompt_tokens: If set to -1, will use the truncation size
+            supported by the model. If set to an integer k, will use only
+            the last k tokens from the prompt (i.e., left truncation).
+            Defaults to None (i.e., no truncation).
         guided_decoding: If provided, the engine will construct a guided
             decoding logits processor from these parameters. Defaults to None.
         logit_bias: If provided, the engine will construct a logits processor
|
||||
tokenizer: AnyTokenizer,
|
||||
text: str,
|
||||
*,
|
||||
truncation: Optional[bool] = None,
|
||||
max_length: Optional[int] = None,
|
||||
add_special_tokens: Optional[bool] = None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
@ -64,10 +66,18 @@ def encode_tokens(
|
||||
:code:`add_special_tokens=None` means to use the backend's default
|
||||
settings.
|
||||
"""
|
||||
if add_special_tokens is not None:
|
||||
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
||||
|
||||
return tokenizer.encode(text)
|
||||
kw_args: dict[str, Any] = {}
|
||||
if max_length is not None:
|
||||
kw_args["max_length"] = max_length
|
||||
|
||||
if truncation is not None:
|
||||
kw_args["truncation"] = truncation
|
||||
|
||||
if add_special_tokens is not None:
|
||||
kw_args["add_special_tokens"] = add_special_tokens
|
||||
|
||||
return tokenizer.encode(text, **kw_args)
|
||||
|
||||
|
||||
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
|
||||
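
With `truncation` and `max_length` now forwarded to the underlying tokenizer, truncation happens at tokenization time. A hedged sketch of the helper in isolation (not part of the commit; the import path and the Hugging Face tokenizer choice are assumptions):

# Sketch only: truncate while tokenizing via the extended helper.
from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import encode_tokens

tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-MiniLM-L12-v2")

token_ids = encode_tokens(tokenizer,
                          "A passage much longer than ten tokens ...",
                          truncation=True,
                          max_length=10)
assert len(token_ids) <= 10
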
@@ -94,6 +94,8 @@ class TokenizerBase(ABC):
     @abstractmethod
     def encode(self,
                text: str,
+               truncation: Optional[bool] = None,
+               max_length: Optional[int] = None,
                add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
@@ -45,11 +45,16 @@ class TokenizerGroup:

     def encode(self,
                prompt: str,
+               max_length: Optional[int] = None,
+               truncation: Optional[bool] = None,
                lora_request: Optional[LoRARequest] = None,
                add_special_tokens: Optional[bool] = None) -> List[int]:
+
         tokenizer = self.get_lora_tokenizer(lora_request)
         ret = encode_tokens(tokenizer,
                             prompt,
+                            max_length=max_length,
+                            truncation=truncation,
                             add_special_tokens=add_special_tokens)
         self._raise_if_input_too_long(ret, lora_request)
         return ret
@@ -57,11 +62,15 @@ class TokenizerGroup:
     async def encode_async(
             self,
             prompt: str,
+            max_length: Optional[int] = None,
+            truncation: Optional[bool] = None,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
         ret = encode_tokens(tokenizer,
                             prompt,
+                            max_length=max_length,
+                            truncation=truncation,
                             add_special_tokens=add_special_tokens)
         self._raise_if_input_too_long(ret, lora_request)
         return ret
@@ -359,6 +359,8 @@ class MistralTokenizer(TokenizerBase):

     def encode(self,
                text: str,
+               truncation: Optional[bool] = None,
+               max_length: Optional[int] = None,
                add_special_tokens: Optional[bool] = None) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
@@ -2,7 +2,7 @@
 import asyncio
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import numpy as np

@@ -201,6 +201,7 @@ class AsyncLLM(EngineClient):
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -219,7 +220,8 @@ class AsyncLLM(EngineClient):
         # Convert Input --> Request.
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+            tokenization_kwargs, trace_headers, prompt_adapter_request,
+            priority)

         if params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
@@ -175,6 +175,7 @@ class LLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -182,7 +183,8 @@ class LLMEngine:
         # Process raw inputs into the request.
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+            tokenization_kwargs, trace_headers, prompt_adapter_request,
+            priority)

         n = params.n if isinstance(params, SamplingParams) else 1
@@ -2,7 +2,7 @@

 import time
 from collections.abc import Mapping, Sequence
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union

 from vllm.config import VllmConfig
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
@@ -198,6 +198,7 @@ class Processor:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -224,6 +225,7 @@ class Processor:
         # 3. Apply prompt adapter to prompt token ids if one exists.
         processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,