Add a flag to include stop string in output text (#1976)

Author: Yunfeng Bai, 2023-12-15 00:45:58 -08:00 (committed by GitHub)
parent 614856da25
commit c06170cc8e
2 changed files with 32 additions and 24 deletions

vllm/engine/llm_engine.py

@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
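
For reference, the new behavior in isolation: a minimal sketch of the truncation rule the hunk above adds, with a hypothetical helper name and plain strings standing in for vLLM's Sequence objects.

    def apply_stop_string(output_text: str, stop_str: str,
                          include_stop_str_in_output: bool) -> str:
        # Mirrors the engine logic: on a stop-string match, the text is
        # truncated unless the new flag asks to keep the stop string.
        if output_text.endswith(stop_str) and not include_stop_str_in_output:
            return output_text[:-len(stop_str)]
        return output_text

    # Default (False) strips the stop string; True keeps it.
    assert apply_stop_string("Hello!###", "###", False) == "Hello!"
    assert apply_stop_string("Hello!###", "###", True) == "Hello!###"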

vllm/sampling_params.py

@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
+
 import torch

 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in output
+            text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -103,6 +106,7 @@
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -140,6 +144,7 @@
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@
         return SamplingType.RANDOM

     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")