Add option to completion API to truncate prompt tokens (#3144)
Repository: vllm-project/vllm
Commit: 1d7c940d74 (parent: cfaf49a167)
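
The change lets a completion request ask the server to keep only the last k prompt tokens. A minimal usage sketch against vLLM's OpenAI-compatible server follows; the address localhost:8000 and the model name "my-model" are assumptions for illustration, not part of the commit.

    import requests

    # Hypothetical server address and model name; any running vLLM OpenAI
    # server built from a version containing this commit behaves the same way.
    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "my-model",
            "prompt": "a very long prompt " * 500,
            "max_tokens": 32,
            # New option: keep only the last 512 prompt tokens (left truncation).
            "truncate_prompt_tokens": 512,
        },
    )
    print(response.json()["choices"][0]["text"])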
vllm/entrypoints/openai/protocol.py

@@ -4,7 +4,7 @@ import time
 from typing import Dict, List, Literal, Optional, Union

 import torch
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, conint, model_validator

 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[conint(ge=1)] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params

@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )

     @model_validator(mode="before")
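
The new request field uses pydantic's conint(ge=1) constraint, so out-of-range values are rejected while the request is being parsed, before it ever reaches the engine. A minimal sketch of that behavior; the _Example model below is a stand-in, not code from this commit.

    from typing import Optional

    from pydantic import BaseModel, ValidationError, conint


    class _Example(BaseModel):
        # Mirrors the annotation added to CompletionRequest above.
        truncate_prompt_tokens: Optional[conint(ge=1)] = None


    print(_Example(truncate_prompt_tokens=512))  # accepted
    try:
        _Example(truncate_prompt_tokens=0)       # rejected: must be >= 1
    except ValidationError as err:
        print(err)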
vllm/entrypoints/openai/serving_completion.py

@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
+                        request,
+                        prompt_ids=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
                 else:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                        request,
+                        prompt=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)

                 generators.append(
                     self.engine.generate(prompt,
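
Both branches now forward sampling_params.truncate_prompt_tokens, so the option applies to string prompts and to pre-tokenized prompts alike. A hedged request sketch for the token-id branch; the server address, model name, and token ids are made up for illustration.

    import requests

    requests.post(
        "http://localhost:8000/v1/completions",      # assumed server address
        json={
            "model": "my-model",                     # hypothetical model name
            "prompt": [101, 7592, 2088, 2003, 102],  # made-up token ids
            "max_tokens": 16,
            # Only the last 3 token ids are kept, mirroring the
            # prompt_ids[-truncate_prompt_tokens:] slice in serving_engine.py.
            "truncate_prompt_tokens": 3,
        },
    )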
vllm/entrypoints/openai/serving_engine.py

@@ -4,6 +4,8 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union

+from pydantic import conint
+
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,

@@ -66,7 +68,8 @@ class OpenAIServing:
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            truncation_side="left")

     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""

@@ -164,15 +167,26 @@ class OpenAIServing:
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
             prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None) -> List[int]:
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")

-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
+                "truncation": True,
+                "max_length": truncate_prompt_tokens,
+            }
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
+
         token_num = len(input_ids)

         if request.max_tokens is None:
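
The truncation itself happens in _validate_prompt_and_tokenize: text prompts rely on the tokenizer, which is now created with truncation_side="left", while pre-tokenized prompts are sliced directly. A standalone sketch of both paths using a Hugging Face tokenizer; the gpt2 checkpoint is only an example choice, not necessarily what the server loads.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

    text = "word " * 500
    k = 8

    # Text prompt: the tokenizer performs the left truncation via
    # truncation=True / max_length=k, as in the diff above.
    truncated = tokenizer(text, truncation=True, max_length=k).input_ids
    assert len(truncated) == k

    # Pre-tokenized prompt: plain slicing keeps the last k token ids.
    prompt_ids = tokenizer(text).input_ids
    assert prompt_ids[-k:] == truncated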
vllm/sampling_params.py

@@ -5,6 +5,7 @@ from functools import cached_property
 from typing import Callable, List, Optional, Union

 import torch
+from pydantic import conint

 _SAMPLING_EPS = 1e-5

@@ -94,6 +95,9 @@ class SamplingParams:
             tokens in the output. Defaults to True.
         logits_processors: List of functions that modify logits based on
             previously generated tokens.
+        truncate_prompt_tokens: If set to an integer k, will use only the last k
+            tokens from the prompt (i.e., left truncation). Defaults to None
+            (i.e., no truncation).
     """

     def __init__(

@@ -123,6 +127,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n

@@ -160,6 +165,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
+        self.truncate_prompt_tokens = truncate_prompt_tokens
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()

@@ -216,6 +222,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if (self.truncate_prompt_tokens is not None
+                and self.truncate_prompt_tokens < 1):
+            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
+                             f"got {self.truncate_prompt_tokens}")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "

@@ -300,4 +310,5 @@ class SamplingParams:
             f"prompt_logprobs={self.prompt_logprobs}, "
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
-            f"{self.spaces_between_special_tokens})")
+            f"{self.spaces_between_special_tokens}, "
+            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
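
Finally, the option is carried on SamplingParams and checked in _verify_args. A short sketch, assuming a vLLM build that already contains this commit:

    from vllm import SamplingParams

    params = SamplingParams(max_tokens=32, truncate_prompt_tokens=512)
    print(params)  # repr now ends with truncate_prompt_tokens=512

    try:
        SamplingParams(truncate_prompt_tokens=0)
    except ValueError as err:
        print(err)  # "truncate_prompt_tokens must be >= 1, got 0"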