mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 18:45:35 +08:00
Add a flag to include stop string in output text (#1976)
This commit is contained in:
parent
614856da25
commit
c06170cc8e
@ -682,6 +682,7 @@ class LLMEngine:
|
|||||||
"""Stop the finished sequences."""
|
"""Stop the finished sequences."""
|
||||||
for stop_str in sampling_params.stop:
|
for stop_str in sampling_params.stop:
|
||||||
if seq.output_text.endswith(stop_str):
|
if seq.output_text.endswith(stop_str):
|
||||||
|
if not sampling_params.include_stop_str_in_output:
|
||||||
# Truncate the output text so that the stop string is
|
# Truncate the output text so that the stop string is
|
||||||
# not included in the output.
|
# not included in the output.
|
||||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
@ -70,6 +71,8 @@ class SamplingParams:
|
|||||||
stop_token_ids: List of tokens that stop the generation when they are
|
stop_token_ids: List of tokens that stop the generation when they are
|
||||||
generated. The returned output will contain the stop tokens unless
|
generated. The returned output will contain the stop tokens unless
|
||||||
the stop tokens are special tokens.
|
the stop tokens are special tokens.
|
||||||
|
include_stop_str_in_output: Whether to include the stop strings in output
|
||||||
|
text. Defaults to False.
|
||||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||||
tokens after the EOS token is generated.
|
tokens after the EOS token is generated.
|
||||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||||
@ -103,6 +106,7 @@ class SamplingParams:
|
|||||||
early_stopping: Union[bool, str] = False,
|
early_stopping: Union[bool, str] = False,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
include_stop_str_in_output: bool = False,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: Optional[int] = None,
|
||||||
@ -140,6 +144,7 @@ class SamplingParams:
|
|||||||
self.skip_special_tokens = skip_special_tokens
|
self.skip_special_tokens = skip_special_tokens
|
||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.logits_processors = logits_processors
|
self.logits_processors = logits_processors
|
||||||
|
self.include_stop_str_in_output = include_stop_str_in_output
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
self._verify_beam_search()
|
self._verify_beam_search()
|
||||||
@ -227,7 +232,8 @@ class SamplingParams:
|
|||||||
return SamplingType.RANDOM
|
return SamplingType.RANDOM
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SamplingParams(n={self.n}, "
|
return (
|
||||||
|
f"SamplingParams(n={self.n}, "
|
||||||
f"best_of={self.best_of}, "
|
f"best_of={self.best_of}, "
|
||||||
f"presence_penalty={self.presence_penalty}, "
|
f"presence_penalty={self.presence_penalty}, "
|
||||||
f"frequency_penalty={self.frequency_penalty}, "
|
f"frequency_penalty={self.frequency_penalty}, "
|
||||||
@ -241,6 +247,7 @@ class SamplingParams:
|
|||||||
f"early_stopping={self.early_stopping}, "
|
f"early_stopping={self.early_stopping}, "
|
||||||
f"stop={self.stop}, "
|
f"stop={self.stop}, "
|
||||||
f"stop_token_ids={self.stop_token_ids}, "
|
f"stop_token_ids={self.stop_token_ids}, "
|
||||||
|
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||||
f"ignore_eos={self.ignore_eos}, "
|
f"ignore_eos={self.ignore_eos}, "
|
||||||
f"max_tokens={self.max_tokens}, "
|
f"max_tokens={self.max_tokens}, "
|
||||||
f"logprobs={self.logprobs}, "
|
f"logprobs={self.logprobs}, "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user