Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:16:06 +08:00)
feat: support stop_token_ids parameter. (#1097)
parent 2d1e86f1b1
commit f98b745a81
vllm/engine/llm_engine.py
@@ -650,6 +650,9 @@ class LLMEngine:
                 seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
+        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
 
         # Check if the sequence has reached max_model_len.
         if seq.get_len() > self.scheduler_config.max_model_len:
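The engine-side hunk above is the core of the feature: after the existing stop-string check, a sequence is also finished once its last sampled token id appears in sampling_params.stop_token_ids. A minimal runnable sketch of that ordering (stop strings first, then stop token ids), using a hypothetical Seq stand-in rather than vLLM's real Sequence class:

from enum import Enum, auto
from typing import List

class SequenceStatus(Enum):
    RUNNING = auto()
    FINISHED_STOPPED = auto()

# Hypothetical stand-in for vLLM's Sequence; only what the sketch needs.
class Seq:
    def __init__(self, output_text: str, token_ids: List[int]):
        self.output_text = output_text
        self.token_ids = token_ids
        self.status = SequenceStatus.RUNNING

    def get_last_token_id(self) -> int:
        return self.token_ids[-1]

def check_stop(seq: Seq, stop: List[str], stop_token_ids: List[int]) -> None:
    # 1) Stop strings: the matched string is truncated out of the output.
    for stop_str in stop:
        if seq.output_text.endswith(stop_str):
            seq.output_text = seq.output_text[:-len(stop_str)]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
    # 2) Stop token ids (new in this commit): the matched token stays in
    #    the output unless the detokenizer skips it as a special token.
    if seq.get_last_token_id() in stop_token_ids:
        seq.status = SequenceStatus.FINISHED_STOPPED
        return

Note the asymmetry the docstring below spells out: stop strings are stripped from output_text, while a matched stop token remains in the returned output unless it is a special token.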
vllm/entrypoints/openai/api_server.py
@@ -217,6 +217,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         temperature=request.temperature,
         top_p=request.top_p,
         stop=request.stop,
+        stop_token_ids=request.stop_token_ids,
         max_tokens=request.max_tokens,
         best_of=request.best_of,
         top_k=request.top_k,
@@ -418,6 +419,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         top_p=request.top_p,
         top_k=request.top_k,
         stop=request.stop,
+        stop_token_ids=request.stop_token_ids,
         ignore_eos=request.ignore_eos,
         max_tokens=request.max_tokens,
         logprobs=request.logprobs,
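With these two endpoint changes, both OpenAI-compatible routes forward the field straight into SamplingParams, so clients can send it in the JSON body. A hedged client example against a locally running server (the port, model name, and token ids are placeholders, not values from this commit):

import requests

# Assumes a vLLM OpenAI-compatible server on localhost:8000.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",              # placeholder model name
        "prompt": "Hello, my name is",
        "max_tokens": 64,
        # New in this commit: generation halts as soon as any of these
        # token ids is sampled.
        "stop_token_ids": [2, 32000],     # placeholder ids
    },
)
print(resp.json())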
vllm/entrypoints/openai/protocol.py
@@ -70,6 +70,7 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class CompletionRequest(BaseModel):
@@ -94,6 +95,7 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class LogProbs(BaseModel):
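On the protocol side, Field(default_factory=list) gives each request object its own fresh empty list instead of a shared mutable default, so an omitted stop_token_ids simply deserializes to []. A small self-contained sketch of that behavior (Req is a stand-in, not the real request model):

from typing import List, Optional
from pydantic import BaseModel, Field

# Stripped-down stand-in for the request models above.
class Req(BaseModel):
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)

print(Req().stop_token_ids)                           # [] per instance
print(Req(stop_token_ids=[2, 32000]).stop_token_ids)  # [2, 32000]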
vllm/sampling_params.py
@@ -45,6 +45,9 @@ class SamplingParams:
             (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
+        stop_token_ids: List of tokens that stop the generation when they are
+            generated. The returned output will contain the stop tokens unless
+            the stop tokens are special tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -64,6 +67,7 @@ class SamplingParams:
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -84,6 +88,10 @@ class SamplingParams:
             self.stop = [stop]
         else:
             self.stop = list(stop)
+        if stop_token_ids is None:
+            self.stop_token_ids = []
+        else:
+            self.stop_token_ids = list(stop_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
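End to end, the new parameter reads like the existing stop strings. A hedged offline-inference sketch (the model name is a placeholder; the right ids come from the model's own tokenizer, e.g. via tokenizer.convert_tokens_to_ids(...)):

from vllm import LLM, SamplingParams

# Placeholder model; any vLLM-supported checkpoint works.
llm = LLM(model="facebook/opt-125m")

params = SamplingParams(
    max_tokens=64,
    # Generation stops as soon as one of these ids is produced. Unlike
    # `stop` strings, the matched token is kept in the output unless it
    # is a special token.
    stop_token_ids=[2],  # 2 is OPT's EOS id; adjust per tokenizer
)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)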