Add a flag to include stop string in output text (#1976)

This commit is contained in:
Yunfeng Bai 2023-12-15 00:45:58 -08:00 committed by GitHub
parent 614856da25
commit c06170cc8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 24 deletions

View File

@ -682,6 +682,7 @@ class LLMEngine:
"""Stop the finished sequences."""
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
if not sampling_params.include_stop_str_in_output:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]

View File

@ -2,6 +2,7 @@
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
import torch
_SAMPLING_EPS = 1e-5
@ -70,6 +71,8 @@ class SamplingParams:
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
include_stop_str_in_output: Whether to include the stop strings in output
text. Defaults to False.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
@ -103,6 +106,7 @@ class SamplingParams:
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
@ -140,6 +144,7 @@ class SamplingParams:
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@ -227,7 +232,8 @@ class SamplingParams:
return SamplingType.RANDOM
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
return (
f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
@ -241,6 +247,7 @@ class SamplingParams:
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs}, "