Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:15:42 +08:00)
Add a flag to include stop string in output text (#1976)
Commit: c06170cc8e
Parent: 614856da25
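For context, a minimal usage sketch of the new flag with vLLM's offline LLM API; the model name and prompt here are illustrative, not part of this commit:

from vllm import LLM, SamplingParams

# With include_stop_str_in_output=True, the matched stop string is kept at
# the end of the returned text instead of being truncated away.
params = SamplingParams(stop=["\n"], include_stop_str_in_output=True)
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], params)
print(repr(outputs[0].outputs[0].text))  # text now ends with "\n"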
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
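In isolation, the changed stopping logic reduces to the following self-contained sketch; the function name and return shape are ours for illustration, not vLLM's API:

def check_stop(output_text, stop, include_stop_str_in_output=False):
    # Mirrors the hunk above: on a stop-string match, truncate only when
    # the flag is unset, and report that the sequence is finished.
    for stop_str in stop:
        if output_text.endswith(stop_str):
            if not include_stop_str_in_output:
                output_text = output_text[:-len(stop_str)]
            return output_text, True
    return output_text, False

assert check_stop("Hi there.\n", ["\n"]) == ("Hi there.", True)
assert check_stop("Hi there.\n", ["\n"], True) == ("Hi there.\n", True)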
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
+
 import torch

 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in output
+            text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -103,6 +106,7 @@ class SamplingParams:
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
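The new argument defaults to False, so keyword-style call sites keep the old truncation behavior unchanged (note it lands before ignore_eos, so any purely positional callers past stop_token_ids would shift). A quick check:

from vllm import SamplingParams

# Default keeps the pre-change behavior: the stop string is truncated.
default_params = SamplingParams(stop=["###"])
assert default_params.include_stop_str_in_output is False

# Opting in keeps the stop string at the end of the returned text.
keep_params = SamplingParams(stop=["###"], include_stop_str_in_output=True)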
@@ -140,6 +144,7 @@ class SamplingParams:
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@ class SamplingParams:
         return SamplingType.RANDOM

     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")
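With the __repr__ update, the flag is visible whenever the parameters are printed or logged; roughly (most fields elided):

from vllm import SamplingParams

print(SamplingParams(include_stop_str_in_output=True))
# -> SamplingParams(n=1, best_of=1, ..., stop_token_ids=[],
#    include_stop_str_in_output=True, ignore_eos=False, max_tokens=16, ...)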