mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:26:00 +08:00)
feat: support stop_token_ids parameter. (#1097)
parent 2d1e86f1b1
commit f98b745a81
```diff
@@ -650,6 +650,9 @@ class LLMEngine:
                 seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
+        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
 
         # Check if the sequence has reached max_model_len.
         if seq.get_len() > self.scheduler_config.max_model_len:
```
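For context, a minimal standalone sketch of the stopping rule this hunk adds. The function and token ids below are simplified stand-ins for illustration, not vLLM's actual classes:

```python
from typing import List


def should_stop(last_token_id: int, stop_token_ids: List[int]) -> bool:
    """Mirror the added check: a sequence finishes as soon as the most
    recently sampled token is one of the user-supplied stop ids."""
    return last_token_id in stop_token_ids


# Token id 50256 (a common EOS id) ends generation; id 11 does not.
assert should_stop(50256, [50256, 198])
assert not should_stop(11, [50256, 198])
```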
```diff
@@ -217,6 +217,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             temperature=request.temperature,
             top_p=request.top_p,
             stop=request.stop,
+            stop_token_ids=request.stop_token_ids,
             max_tokens=request.max_tokens,
             best_of=request.best_of,
             top_k=request.top_k,
```
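With the field plumbed through here, a client can pass stop_token_ids on the chat endpoint. A hedged sketch using `requests`; the server URL, model name, and token ids are assumptions, not part of this commit:

```python
import requests

# Assumes a vLLM OpenAI-compatible server running on localhost:8000.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
        "stop_token_ids": [2],  # e.g. OPT's </s> id (illustrative)
    },
)
print(resp.json())
```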
```diff
@@ -418,6 +419,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             top_p=request.top_p,
             top_k=request.top_k,
             stop=request.stop,
+            stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
```
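The same wiring applies to the plain completion endpoint; a sketch under the same assumptions:

```python
import requests

# Values (URL, model, token ids) are illustrative.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "max_tokens": 32,
        "stop_token_ids": [2],
    },
)
print(resp.json()["choices"][0]["text"])
```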
```diff
@@ -70,6 +70,7 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class CompletionRequest(BaseModel):
```
```diff
@@ -94,6 +95,7 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class LogProbs(BaseModel):
```
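Both request models default the new field with `Field(default_factory=list)`, so a request that omits it parses to an empty list rather than None. A minimal sketch of that behavior; the class name is a stand-in, not the real model:

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class RequestSketch(BaseModel):
    # Mirrors the added field: omitting it yields an empty list.
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)


print(RequestSketch().stop_token_ids)                           # []
print(RequestSketch(stop_token_ids=[2, 50256]).stop_token_ids)  # [2, 50256]
```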
```diff
@@ -45,6 +45,9 @@ class SamplingParams:
             (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
+        stop_token_ids: List of tokens that stop the generation when they are
+            generated. The returned output will contain the stop tokens unless
+            the stop tokens are special tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
```
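Because stop_token_ids operates on token ids rather than strings, callers typically look the ids up through the tokenizer. A hedged sketch using Hugging Face transformers; the model name is illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# A special token such as EOS: per the docstring above, special stop
# tokens are not included in the returned output.
print(tokenizer.eos_token_id)

# Ordinary token ids, by contrast, remain in the output when hit.
print(tokenizer.encode("\n", add_special_tokens=False))
```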
```diff
@@ -64,6 +67,7 @@ class SamplingParams:
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
```
```diff
@@ -84,6 +88,10 @@
             self.stop = [stop]
         else:
             self.stop = list(stop)
+        if stop_token_ids is None:
+            self.stop_token_ids = []
+        else:
+            self.stop_token_ids = list(stop_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
```
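Putting it together, the parameter can also be passed when constructing SamplingParams for offline generation. A hedged sketch; the model and token ids are illustrative and model-specific:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# Stop as soon as any listed token id is sampled (2 is OPT's </s>;
# the second id is illustrative).
params = SamplingParams(max_tokens=64, stop_token_ids=[2, 50118])

for output in llm.generate(["Hello, my name is"], params):
    print(output.outputs[0].text)
```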