diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 0b691feb8483..f94d22d279cc 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -4,7 +4,7 @@ import time
 from typing import Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, conint, model_validator
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[conint(ge=1)] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )
 
     @model_validator(mode="before")
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 3d1b16f52817..06e7a9225fef 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
+                        request,
+                        prompt_ids=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
                 else:
                     input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                        request,
+                        prompt=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
 
                 generators.append(
                     self.engine.generate(prompt,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 9dbd1750e631..8f69388c0251 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -4,6 +4,8 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
+from pydantic import conint
+
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
@@ -66,7 +68,8 @@ class OpenAIServing:
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            truncation_side="left")
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -164,15 +167,26 @@ class OpenAIServing:
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
             prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None) -> List[int]:
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
+                "truncation": True,
+                "max_length": truncate_prompt_tokens,
+            }
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
+
         token_num = len(input_ids)
 
         if request.max_tokens is None:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index bbba02a833fc..4fdc3c6dedae 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -5,6 +5,7 @@ from functools import cached_property
 from typing import Callable, List, Optional, Union
 
 import torch
+from pydantic import conint
 
 _SAMPLING_EPS = 1e-5
 
@@ -94,6 +95,9 @@ class SamplingParams:
             tokens in the output. Defaults to True.
         logits_processors: List of functions that modify logits based on
             previously generated tokens.
+        truncate_prompt_tokens: If set to an integer k, will use only the last k
+            tokens from the prompt (i.e., left truncation). Defaults to None
+            (i.e., no truncation).
     """
 
     def __init__(
@@ -123,6 +127,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -160,6 +165,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
+        self.truncate_prompt_tokens = truncate_prompt_tokens
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -216,6 +222,10 @@ class SamplingParams:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if (self.truncate_prompt_tokens is not None
+                and self.truncate_prompt_tokens < 1):
+            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
+                             f"got {self.truncate_prompt_tokens}")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
@@ -300,4 +310,5 @@ class SamplingParams:
             f"prompt_logprobs={self.prompt_logprobs}, "
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
-            f"{self.spaces_between_special_tokens})")
+            f"{self.spaces_between_special_tokens}, "
+            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
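
For reference, a minimal usage sketch (not part of the diff): it assumes an OpenAI-compatible vLLM server is already running locally and that `facebook/opt-125m` is the served model name; both are assumptions for illustration. With `truncate_prompt_tokens` set on a completion request, only the last k prompt tokens are kept (left truncation) instead of the request failing for exceeding the model's context length.

```python
# Hedged example: server URL and model name below are assumptions.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",   # assumed served model
        "prompt": "word " * 10_000,     # deliberately over-long prompt
        "max_tokens": 16,
        # Keep only the last 128 prompt tokens (left truncation).
        "truncate_prompt_tokens": 128,
    },
)
print(response.json()["choices"][0]["text"])
```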