Add skip_special_tokens sampling params (#1186)
commit 20f7cc4cde
parent 649aa730c5
@@ -387,7 +387,7 @@ class LLMEngine:
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -621,7 +621,8 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
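
For context, the flag threaded through _decode_sequence above is the same skip_special_tokens switch that Hugging Face tokenizers accept at decode time: it decides whether markers such as BOS/EOS appear in the returned text. A minimal illustration of the difference, assuming the transformers package and the facebook/opt-125m tokenizer (both chosen only for this sketch, not taken from the diff):

    # Illustration of what skip_special_tokens changes at decode time.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    token_ids = tokenizer("Hello world").input_ids  # OPT prepends a BOS token

    # Special tokens stripped (the hard-coded behaviour before this commit):
    print(tokenizer.decode(token_ids, skip_special_tokens=True))   # Hello world
    # Special tokens kept verbatim (now selectable per request):
    print(tokenizer.decode(token_ids, skip_special_tokens=False))  # </s>Hello world
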
@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
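
Once the server hunks above are in place, an OpenAI-compatible client can opt out of special-token stripping per request. A hedged sketch of such a request; the host/port, model name, and prompt are placeholders for this example rather than values from the diff:

    # Per-request override via the OpenAI-compatible completions endpoint.
    import requests

    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            "prompt": "Hello, my name is",
            "max_tokens": 16,
            # New field added by this commit; defaults to True when omitted.
            "skip_special_tokens": False,
        },
    )
    print(response.json()["choices"][0]["text"])
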
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
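
Because both request models declare the field as Optional[bool] = True, clients that never send it keep the old behaviour. A stripped-down stand-in (not the real ChatCompletionRequest/CompletionRequest, only the field relevant here) shows the defaulting:

    # Stripped-down stand-in for the request models above.
    from typing import Optional
    from pydantic import BaseModel

    class RequestSketch(BaseModel):
        prompt: str
        skip_special_tokens: Optional[bool] = True

    print(RequestSketch(prompt="hi").skip_special_tokens)                             # True
    print(RequestSketch(prompt="hi", skip_special_tokens=False).skip_special_tokens)  # False
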
@@ -60,6 +60,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """
 
     def __init__(
@@ -79,6 +81,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -103,6 +106,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -196,4 +200,5 @@ class SamplingParams:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"skip_special_tokens={self.skip_special_tokens})")
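
For offline use, the new SamplingParams field is passed straight to LLM.generate. A sketch under the assumption that vLLM is installed and that facebook/opt-125m serves purely as a small example model:

    # Offline generation with special tokens preserved in the output text.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(max_tokens=32, skip_special_tokens=False)

    for output in llm.generate(["The capital of France is"], params):
        # With skip_special_tokens=False, any EOS/BOS markers the model emits
        # remain in the text instead of being stripped during detokenization.
        print(output.outputs[0].text)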