From 20f7cc4cdebcbcad788fbe1b06e5e07f8d145b77 Mon Sep 17 00:00:00 2001
From: Dan Lord
Date: Wed, 27 Sep 2023 19:21:42 -0700
Subject: [PATCH] Add `skip_special_tokens` sampling params (#1186)

---
 vllm/engine/llm_engine.py             | 7 ++++---
 vllm/entrypoints/openai/api_server.py | 2 ++
 vllm/entrypoints/openai/protocol.py   | 2 ++
 vllm/sampling_params.py               | 7 ++++++-
 4 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 65095be9c74c..9922bda7d14c 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -387,7 +387,7 @@ class LLMEngine:
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -621,7 +621,8 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d260396e47c4..643dd06cb17d 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 473400a7faf9..12b7453de819 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 53bd743fce9d..5206eb0b8c4d 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -60,6 +60,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """
 
     def __init__(
@@ -79,6 +81,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -103,6 +106,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -196,4 +200,5 @@ class SamplingParams:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"skip_special_tokens={self.skip_special_tokens})")
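
Usage note (illustrative only, not part of the patch): the sketch below shows
one way the new flag could be exercised through vLLM's offline `LLM`
entrypoint. The model name and prompt are placeholders; passing
skip_special_tokens=False keeps special tokens (e.g. the EOS marker) in the
decoded output text, while the default of True preserves the previous
behavior.

    from vllm import LLM, SamplingParams

    # Any supported model works here; "facebook/opt-125m" is only an example.
    llm = LLM(model="facebook/opt-125m")

    # Keep special tokens in the generated text instead of stripping them.
    params = SamplingParams(max_tokens=32, skip_special_tokens=False)

    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)

The same toggle is exposed on the OpenAI-compatible server via the new
skip_special_tokens field on ChatCompletionRequest and CompletionRequest.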