feat: support stop_token_ids parameter. (#1097)

This commit is contained in:
Ricardo Lu 2023-09-22 06:34:02 +08:00 committed by GitHub
parent 2d1e86f1b1
commit f98b745a81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 0 deletions

View File

@ -650,6 +650,9 @@ class LLMEngine:
seq.output_text = seq.output_text[:-len(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED
return
if seq.get_last_token_id() in sampling_params.stop_token_ids:
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:

View File

@ -217,6 +217,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
@ -418,6 +419,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,

View File

@ -70,6 +70,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
class CompletionRequest(BaseModel):
@ -94,6 +95,7 @@ class CompletionRequest(BaseModel):
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
class LogProbs(BaseModel):

View File

@ -45,6 +45,9 @@ class SamplingParams:
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
@ -64,6 +67,7 @@ class SamplingParams:
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
stop_token_ids: List[int] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
@ -84,6 +88,10 @@ class SamplingParams:
self.stop = [stop]
else:
self.stop = list(stop)
if stop_token_ids is None:
self.stop_token_ids = []
else:
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs