mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:15:01 +08:00
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
350 lines
13 KiB
Python
350 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional
|
||
|
||
import tokenizers
|
||
from packaging import version
|
||
from tokenizers import Tokenizer
|
||
from tokenizers.decoders import DecodeStream
|
||
from transformers import PreTrainedTokenizerFast
|
||
|
||
from vllm.logger import init_logger
|
||
from vllm.transformers_utils.detokenizer_utils import (
|
||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||
from vllm.v1.engine import EngineCoreRequest
|
||
|
||
logger = init_logger(__name__)

# FastIncrementalDetokenizer relies on tokenizers' DecodeStream, which is only
# available from tokenizers 0.21.1 onwards; older releases take the slow path.
USE_FAST_DETOKENIZER = (version.parse(tokenizers.__version__)
                        >= version.parse("0.21.1"))

# Prefix of the error message DecodeStream.step() raises when its internal
# prefix state becomes invalid. Error string taken from
# https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||
|
||
|
||
class IncrementalDetokenizer:
    """Minimal detokenizer that only tracks token ids.

    This implementation decodes no text and never matches stop strings; it is
    used as-is when no tokenizer is available. ``from_new_request`` picks the
    concrete implementation for a request.
    """

    def __init__(self):
        # Token ids accumulated via ``update`` (subclasses may also seed this
        # list with the prompt's token ids).
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Return the token ids that count as generated output."""
        return self.token_ids

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Record ``new_token_ids``; this base class never matches a stop.

        ``stop_terminated`` is accepted for interface compatibility and is
        ignored here.
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return an empty string, since no text is ever decoded."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Build the detokenizer suited to ``tokenizer`` for ``request``.

        Returns:
            A plain ``IncrementalDetokenizer`` when ``tokenizer`` is None;
            a ``FastIncrementalDetokenizer`` when ``USE_FAST_DETOKENIZER`` is
            set and ``tokenizer`` is a ``PreTrainedTokenizerFast``; otherwise
            a ``SlowIncrementalDetokenizer``.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # Without a tokenizer there is nothing to detokenize.
            return IncrementalDetokenizer()

        fast_supported = USE_FAST_DETOKENIZER and isinstance(
            tokenizer, PreTrainedTokenizerFast)
        if not fast_supported:
            # Python-based incremental detokenization.
            return SlowIncrementalDetokenizer(tokenizer, request)

        # Delegate to the tokenizers library's DecodeStream.
        return FastIncrementalDetokenizer(tokenizer, request)
|
||
|
||
|
||
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared text accumulation, stop-string handling and output streaming.

    Subclasses implement ``decode_next``, which returns the text contributed
    by a single token id.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        params = request.sampling_params
        assert params is not None

        # Stop-string configuration taken from the sampling params.
        self.stop = params.stop
        self.min_tokens = params.min_tokens
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # When stop strings are excluded from the output, streamed output
        # withholds the trailing (longest stop string length - 1) characters,
        # so the beginning of a stop string still being generated is not
        # emitted before it can be truncated away.
        hold_back = 0
        if self.stop and not self.include_stop_str_in_output:
            hold_back = max(map(len, self.stop)) - 1
        self.stop_buffer_length = hold_back
        # End of the text already handed out by delta-mode
        # ``get_next_output_text`` calls.
        self._last_output_text_offset: int = 0

        # Decoded generation text so far.
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Decode ``new_token_ids`` incrementally, then check stop strings.

        Args:
            new_token_ids: Token ids produced since the previous call.
            stop_terminated: Whether the last id in ``new_token_ids`` stopped
                the request. If so and stop strings are not to be included in
                the output, that id is recorded in ``token_ids`` but is not
                decoded into ``output_text``.

        Returns:
            The matched stop string, or None when nothing matched. On a
            match ``output_text`` may be truncated.
        """
        if not new_token_ids:
            # Nothing new to decode.
            return None

        deferred_stop_id: Optional[int] = None
        to_decode = new_token_ids
        if stop_terminated and not self.include_stop_str_in_output:
            # Keep the terminating token out of the decoded text.
            deferred_stop_id = new_token_ids[-1]
            to_decode = new_token_ids[:-1]

        # TODO(woosuk): This method becomes very inefficient when the number
        # of new_token_ids is more than 1. We need to optimize this.
        # Stop strings are only searched for in text at or after this offset.
        check_from = len(self.output_text)
        for token_id in to_decode:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            # Text produced while the output is still within min_tokens is
            # excluded from stop-string matching.
            # See https://github.com/vllm-project/vllm/pull/22014
            if self.min_tokens and len(
                    self.output_token_ids) <= self.min_tokens:
                check_from = len(self.output_text)

        if deferred_stop_id is not None:
            # Record the skipped terminating token after decoding.
            self.token_ids.append(deferred_stop_id)

        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None

        result = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - check_from,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if result is None:
            return None

        stop_string, truncate_to = result
        if truncate_to != -1:
            self.output_text = self.output_text[:truncate_to]
        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text contributed by ``next_token_id``.

        ``update`` calls this after appending ``next_token_id`` to
        ``token_ids``.
        """
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text for the caller.

        With ``delta`` True, only text not returned by an earlier delta call
        is returned; otherwise the whole text is returned. Unless
        ``finished`` is True, the last ``stop_buffer_length`` characters are
        withheld.
        """
        # A finished sequence has nothing left to hold back.
        held = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-held] if held else self.output_text

        visible_end = len(self.output_text) - held
        start = self._last_output_text_offset
        if start >= visible_end:
            return ""
        self._last_output_text_offset = visible_end
        return self.output_text[start:visible_end]
|
||
|
||
|
||
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by the tokenizers ``DecodeStream``.

    Used for ``PreTrainedTokenizerFast`` tokenizers. The stream is first fed a
    suffix of the prompt so it has decoding context before the first
    generated token arrives.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        # Backend tokenizers.Tokenizer wrapped by the HF fast tokenizer.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Prime the stream with a prompt suffix that starts at a safe place.
        prompt_token_ids = request.prompt_token_ids or []
        for tid in self._safe_prompt_suffix(prompt_token_ids):
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Map added-token id -> token content so that the spaces between
            # consecutive added tokens can be suppressed. The map is cached as
            # an attribute on the tokenizers.Tokenizer object, so requests
            # sharing that object reuse it.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens, so there is nothing to suppress.
                self.spaces_between_special_tokens = True

    def _safe_prompt_suffix(self, prompt_token_ids: list[int]) -> list[int]:
        """Return the prompt suffix to prime the decode stream with.

        For prompts longer than 4 tokens, tries suffixes of 4 up to 23 tokens
        (bounded by the prompt length) and returns the shortest one whose
        decoded text contains no U+FFFD replacement character, i.e. one that
        does not appear to begin part-way through a multi-byte character.
        Otherwise falls back to the whole prompt.
        """
        prompt_len = len(prompt_token_ids)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                # U+FFFD (REPLACEMENT CHARACTER) marks undecodable bytes.
                if "\ufffd" not in self.tokenizer.decode(suffix):
                    return suffix
        return prompt_token_ids

    def decode_next(self, next_token_id: int) -> str:
        """Step the stream with ``next_token_id`` and return the new text.

        When spaces between special tokens are disabled, an added token that
        directly follows another added token is returned as its raw content,
        without any prefix the stream produced for it.
        """
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """Advance the decode stream by one token, recovering known failures.

        Returns:
            The text produced by the step (possibly None), or None when the
            step raised OverflowError.

        Raises:
            Any exception from ``DecodeStream.step`` other than OverflowError
            or the invalid-prefix error; the invalid-prefix error triggers a
            stream reset and a single retry.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except OverflowError:
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %d", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.", self.request_id)
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
|
||
|
||
|
||
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer built on
    ``detokenize_incrementally``.

    Used when the fast ``DecodeStream`` path is not available. Prompt token
    ids are stored at the front of ``token_ids``; ``output_token_ids`` skips
    them.
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds)

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                ))
        else:
            # Prompt embedding requests cannot be detokenized, in general;
            # use empty placeholder tokens for the prompt positions.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # decode_next reads self.read_offset, so it must be set here.
            self.read_offset = 0

        # Without prompt token ids, placeholder id 0 stands in for each
        # prompt position so that prompt_len slicing stays consistent.
        self.token_ids.extend(request.prompt_token_ids
                              or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        """Return the token ids that follow the prompt."""
        return self.token_ids if not self.prompt_len else (
            self.token_ids[self.prompt_len:])

    def decode_next(self, next_token_id: int) -> str:
        """Return the text added by the most recently appended token.

        ``next_token_id`` is not read directly: ``update`` has already
        appended it to ``token_ids``, which is passed as ``all_input_ids``.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=self.tokenizer,
                all_input_ids=self.token_ids,
                prev_tokens=self.tokens,
                prefix_offset=self.prefix_offset,
                read_offset=self.read_offset,
                skip_special_tokens=self.skip_special_tokens,
                spaces_between_special_tokens=self.
                spaces_between_special_tokens,
            ))

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
|
||
|
||
|
||
def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
) -> Optional[tuple[str, int]]:
    """Look for a stop string overlapping the newly added output text.

    Only the tail of ``output_text`` that could contain a stop string ending
    within the last ``new_char_count`` characters is searched. Stop strings
    are tried in the order given; the first one found wins.

    Returns:
        None if nothing matched, otherwise ``(stop_string, offset)``, where
        ``offset`` is the length ``output_text`` should be truncated to
        (the start of the match, or its end when ``include_in_output``), or
        -1 when no truncation is needed.
    """
    if not (new_char_count and stop):
        return None

    text_len = len(output_text)
    for candidate in stop:
        candidate_len = len(candidate)
        # A negative start index counts from the end of the text, so text
        # that was already searched on earlier calls is skipped.
        match_at = output_text.find(candidate,
                                    1 - new_char_count - candidate_len)
        if match_at < 0:
            continue

        if not include_in_output:
            # Cut the text at the beginning of the stop string.
            return candidate, match_at

        match_end = match_at + candidate_len
        # Keep the stop string; truncate only if text follows it.
        return (candidate, -1) if match_end >= text_len else (candidate,
                                                              match_end)
    return None
|