Add option to completion API to truncate prompt tokens (#3144)

Thomas Parnell 2024-04-05 19:15:42 +02:00 committed by GitHub
parent cfaf49a167
commit 1d7c940d74
4 changed files with 41 additions and 8 deletions
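
A usage sketch (the server URL, port, and model name below are assumptions for illustration, not part of this change): a client can now ask the OpenAI-compatible /v1/completions endpoint to keep only the last k prompt tokens.

    # Hedged example: assumes a vLLM OpenAI-compatible server running locally.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",   # assumed server address
        json={
            "model": "meta-llama/Llama-2-7b-hf",  # hypothetical model name
            "prompt": "A very long prompt that may exceed the context window ...",
            "max_tokens": 64,
            # New in this change: keep only the last 1024 prompt tokens.
            "truncate_prompt_tokens": 1024,
        },
    )
    print(resp.json()["choices"][0]["text"])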

File: vllm/entrypoints/openai/protocol.py

@@ -4,7 +4,7 @@ import time
 from typing import Dict, List, Literal, Optional, Union
 import torch
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, conint, model_validator
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[conint(ge=1)] = None
     # doc: end-completion-sampling-params
     # doc: begin-completion-extra-params
@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )
     @model_validator(mode="before")

File: vllm/entrypoints/openai/serving_completion.py

@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
+                        request,
+                        prompt_ids=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
                 else:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                        request,
+                        prompt=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
                 generators.append(
                     self.engine.generate(prompt,

File: vllm/entrypoints/openai/serving_engine.py

@@ -4,6 +4,8 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
+from pydantic import conint
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
@@ -66,7 +68,8 @@ class OpenAIServing:
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            truncation_side="left")
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -164,15 +167,26 @@ class OpenAIServing:
         self,
         request: Union[ChatCompletionRequest, CompletionRequest],
         prompt: Optional[str] = None,
-        prompt_ids: Optional[List[int]] = None) -> List[int]:
+        prompt_ids: Optional[List[int]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
+                "truncation": True,
+                "max_length": truncate_prompt_tokens,
+            }
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
         token_num = len(input_ids)
         if request.max_tokens is None:
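
The two code paths above can be illustrated with a short sketch. The Hugging Face tokenizer call mirrors the string-prompt branch (the tokenizer is created with truncation_side="left", as in the change above), and the negative slice mirrors the prompt_ids branch; the gpt2 tokenizer here is only an illustrative assumption.

    # Sketch of left truncation for both prompt kinds (illustrative only).
    from transformers import AutoTokenizer

    truncate_prompt_tokens = 3

    # String prompt: with truncation_side="left", truncation keeps the *last*
    # max_length tokens rather than the first ones.
    tok = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")
    ids = tok("one two three four five six",
              truncation=True,
              max_length=truncate_prompt_tokens).input_ids
    print(ids)  # token ids of " four five six"

    # Pre-tokenized prompt: the same effect comes from a negative slice.
    prompt_ids = [101, 102, 103, 104, 105, 106]
    print(prompt_ids[-truncate_prompt_tokens:])  # [104, 105, 106]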

File: vllm/sampling_params.py

@@ -5,6 +5,7 @@ from functools import cached_property
 from typing import Callable, List, Optional, Union
 import torch
+from pydantic import conint
 _SAMPLING_EPS = 1e-5
@@ -94,6 +95,9 @@ class SamplingParams:
             tokens in the output. Defaults to True.
         logits_processors: List of functions that modify logits based on
             previously generated tokens.
+        truncate_prompt_tokens: If set to an integer k, will use only the last k
+            tokens from the prompt (i.e., left truncation). Defaults to None
+            (i.e., no truncation).
     """
     def __init__(
@@ -123,6 +127,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -160,6 +165,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
+        self.truncate_prompt_tokens = truncate_prompt_tokens
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -216,6 +222,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if (self.truncate_prompt_tokens is not None
+                and self.truncate_prompt_tokens < 1):
+            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
+                             f"got {self.truncate_prompt_tokens}")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
@@ -300,4 +310,5 @@ class SamplingParams:
             f"prompt_logprobs={self.prompt_logprobs}, "
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
-            f"{self.spaces_between_special_tokens})")
+            f"{self.spaces_between_special_tokens}, "
+            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
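
For completeness, a sketch of the new SamplingParams field itself, constructed directly here for illustration. Note that in this change the actual truncation happens in the OpenAI serving layer shown above; SamplingParams only carries the value and validates it.

    # Hedged example of the new field and its validation (>= 1).
    from vllm import SamplingParams

    params = SamplingParams(max_tokens=64, truncate_prompt_tokens=1024)
    print(params)  # repr now includes truncate_prompt_tokens=1024

    try:
        SamplingParams(truncate_prompt_tokens=0)
    except ValueError as err:
        print(err)  # truncate_prompt_tokens must be >= 1, got 0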