mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:25:01 +08:00
Add skip_special_tokens sampling params (#1186)
This commit is contained in:
parent
649aa730c5
commit
20f7cc4cde
@ -387,7 +387,7 @@ class LLMEngine:
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._decode_sequence(seq, seq_group.sampling_params)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
@ -621,7 +621,8 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
def _decode_sequence(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
(new_tokens, new_output_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
@ -630,7 +631,7 @@ class LLMEngine:
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=True,
|
||||
skip_special_tokens=sampling_params.skip_special_tokens,
|
||||
)
|
||||
if seq.tokens is None:
|
||||
seq.tokens = new_tokens
|
||||
|
||||
@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
max_tokens=request.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
|
||||
@ -60,6 +60,8 @@ class SamplingParams:
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
Defaults to true.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -79,6 +81,7 @@ class SamplingParams:
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@ -103,6 +106,7 @@ class SamplingParams:
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
@ -196,4 +200,5 @@ class SamplingParams:
|
||||
f"stop={self.stop}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"logprobs={self.logprobs})")
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens})")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user