mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
Add option to completion API to truncate prompt tokens (#3144)
This commit is contained in:
parent
cfaf49a167
commit
1d7c940d74
@ -4,7 +4,7 @@ import time
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, conint, model_validator
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
|
||||
min_tokens: Optional[int] = 0
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
# doc: end-completion-sampling-params
|
||||
|
||||
# doc: begin-completion-extra-params
|
||||
@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
length_penalty=self.length_penalty,
|
||||
logits_processors=logits_processors,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
for i, prompt in enumerate(prompts):
|
||||
if prompt_is_tokens:
|
||||
input_ids = self._validate_prompt_and_tokenize(
|
||||
request, prompt_ids=prompt)
|
||||
request,
|
||||
prompt_ids=prompt,
|
||||
truncate_prompt_tokens=sampling_params.
|
||||
truncate_prompt_tokens)
|
||||
else:
|
||||
input_ids = self._validate_prompt_and_tokenize(
|
||||
request, prompt=prompt)
|
||||
request,
|
||||
prompt=prompt,
|
||||
truncate_prompt_tokens=sampling_params.
|
||||
truncate_prompt_tokens)
|
||||
|
||||
generators.append(
|
||||
self.engine.generate(prompt,
|
||||
|
||||
@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import conint
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest, ErrorResponse,
|
||||
@ -66,7 +68,8 @@ class OpenAIServing:
|
||||
self.tokenizer = get_tokenizer(
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
trust_remote_code=engine_model_config.trust_remote_code,
|
||||
truncation_side="left")
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
@ -164,15 +167,26 @@ class OpenAIServing:
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None) -> List[int]:
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
) -> List[int]:
|
||||
if not (prompt or prompt_ids):
|
||||
raise ValueError("Either prompt or prompt_ids should be provided.")
|
||||
if (prompt and prompt_ids):
|
||||
raise ValueError(
|
||||
"Only one of prompt or prompt_ids should be provided.")
|
||||
|
||||
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
|
||||
prompt).input_ids
|
||||
if prompt_ids is None:
|
||||
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
|
||||
"truncation": True,
|
||||
"max_length": truncate_prompt_tokens,
|
||||
}
|
||||
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
|
||||
elif truncate_prompt_tokens is not None:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
else:
|
||||
input_ids = prompt_ids
|
||||
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
|
||||
@ -5,6 +5,7 @@ from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import conint
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -94,6 +95,9 @@ class SamplingParams:
|
||||
tokens in the output. Defaults to True.
|
||||
logits_processors: List of functions that modify logits based on
|
||||
previously generated tokens.
|
||||
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
||||
tokens from the prompt (i.e., left truncation). Defaults to None
|
||||
(i.e., no truncation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -123,6 +127,7 @@ class SamplingParams:
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@ -160,6 +165,7 @@ class SamplingParams:
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.logits_processors = logits_processors
|
||||
self.include_stop_str_in_output = include_stop_str_in_output
|
||||
self.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
@ -216,6 +222,10 @@ class SamplingParams:
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
if (self.truncate_prompt_tokens is not None
|
||||
and self.truncate_prompt_tokens < 1):
|
||||
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
|
||||
f"got {self.truncate_prompt_tokens}")
|
||||
if self.stop and not self.detokenize:
|
||||
raise ValueError(
|
||||
"stop strings are only supported when detokenize is True. "
|
||||
@ -300,4 +310,5 @@ class SamplingParams:
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens})")
|
||||
f"{self.spaces_between_special_tokens}, "
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user