Add a flag to include stop string in output text (#1976)

Yunfeng Bai authored 2023-12-15 00:45:58 -08:00, committed by GitHub
parent 614856da25
commit c06170cc8e
2 changed files with 32 additions and 24 deletions

vllm/engine/llm_engine.py

@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
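In words: when a sequence's output ends with one of the configured stop strings, the engine previously always trimmed that string off before finishing the sequence; after this change, trimming only happens when include_stop_str_in_output is False (the default). A minimal standalone sketch of the same decision (the function name and free-standing form are illustrative, not the engine's actual API):

from typing import List

def trim_stop_string(output_text: str, stop: List[str],
                     include_stop_str_in_output: bool = False) -> str:
    """Illustrative mirror of the engine's stop-string handling."""
    for stop_str in stop:
        if output_text.endswith(stop_str):
            if not include_stop_str_in_output:
                # Default behavior: drop the matched stop string.
                return output_text[:-len(stop_str)]
            # New behavior: keep the stop string in the returned text.
            return output_text
    return output_text

assert trim_stop_string("Hello.\n###", ["###"]) == "Hello.\n"
assert trim_stop_string("Hello.\n###", ["###"], True) == "Hello.\n###"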

vllm/sampling_params.py

@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
 
 import torch
 
 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in
+            output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
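Assuming vLLM's usual offline entry point (LLM and SamplingParams are the library's public classes; the model name below is just a placeholder), the new flag would be used like this:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

params = SamplingParams(
    max_tokens=64,
    stop=["\n\n"],
    include_stop_str_in_output=True,  # keep "\n\n" at the end of .text
)
outputs = llm.generate(["Write one sentence about GPUs."], params)
print(outputs[0].outputs[0].text)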
@@ -103,6 +106,7 @@
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
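Because the parameter defaults to False, existing callers are unaffected; a quick illustrative check:

from vllm import SamplingParams

params = SamplingParams(stop=["###"])              # flag not passed
assert params.include_stop_str_in_output is False  # default keeps the old behavior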
@@ -140,6 +144,7 @@
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@
         return SamplingType.RANDOM
 
     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")