From c06170cc8e324f4fe6a0c26b57d09e8c958e11bc Mon Sep 17 00:00:00 2001
From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com>
Date: Fri, 15 Dec 2023 00:45:58 -0800
Subject: [PATCH] Add a flag to include stop string in output text (#1976)

---
 vllm/engine/llm_engine.py |  7 +++---
 vllm/sampling_params.py   | 49 ++++++++++++++++++++++-----------------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a1acdfde449a..e5a126705ed2 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 38b7c0b531bd..30a8036a63fc 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
+
 import torch
 
 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in
+            output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -103,6 +106,7 @@ class SamplingParams:
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -140,6 +144,7 @@ class SamplingParams:
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@ class SamplingParams:
         return SamplingType.RANDOM
 
     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")
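
Usage note (not part of the patch): a minimal sketch of exercising the new
flag through the offline LLM entry point. The model name and prompt are
illustrative only; any model that vLLM supports at this commit works.

    from vllm import LLM, SamplingParams

    # The model name here is only illustrative.
    llm = LLM(model="facebook/opt-125m")

    # With include_stop_str_in_output=True, a matched stop string ("\n\n")
    # stays at the end of output_text instead of being truncated away by
    # the _check_stop path patched above.
    params = SamplingParams(
        temperature=0.0,
        max_tokens=64,
        stop=["\n\n"],
        include_stop_str_in_output=True,
    )

    for out in llm.generate(["Q: What is vLLM?\nA:"], params):
        # Ends with "\n\n" whenever generation stopped on the stop string
        # (it will not, e.g., if max_tokens was reached first).
        print(repr(out.outputs[0].text))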