mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
Add a flag to include stop string in output text (#1976)
This commit is contained in:
parent
614856da25
commit
c06170cc8e
@ -682,9 +682,10 @@ class LLMEngine:
|
||||
"""Stop the finished sequences."""
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
if not sampling_params.include_stop_str_in_output:
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
@ -70,6 +71,8 @@ class SamplingParams:
|
||||
stop_token_ids: List of tokens that stop the generation when they are
|
||||
generated. The returned output will contain the stop tokens unless
|
||||
the stop tokens are special tokens.
|
||||
include_stop_str_in_output: Whether to include the stop strings in output
|
||||
text. Defaults to False.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
@ -103,6 +106,7 @@ class SamplingParams:
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
@ -140,6 +144,7 @@ class SamplingParams:
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.logits_processors = logits_processors
|
||||
self.include_stop_str_in_output = include_stop_str_in_output
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
@ -227,24 +232,26 @@ class SamplingParams:
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SamplingParams(n={self.n}, "
|
||||
f"best_of={self.best_of}, "
|
||||
f"presence_penalty={self.presence_penalty}, "
|
||||
f"frequency_penalty={self.frequency_penalty}, "
|
||||
f"repetition_penalty={self.repetition_penalty}, "
|
||||
f"temperature={self.temperature}, "
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"min_p={self.min_p}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"stop_token_ids={self.stop_token_ids}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens})")
|
||||
return (
|
||||
f"SamplingParams(n={self.n}, "
|
||||
f"best_of={self.best_of}, "
|
||||
f"presence_penalty={self.presence_penalty}, "
|
||||
f"frequency_penalty={self.frequency_penalty}, "
|
||||
f"repetition_penalty={self.repetition_penalty}, "
|
||||
f"temperature={self.temperature}, "
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"min_p={self.min_p}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"stop_token_ids={self.stop_token_ids}, "
|
||||
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens})")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user