mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:05:44 +08:00
Use TGI-like incremental detokenization (#984)
This commit is contained in:
parent
3272d7a0b7
commit
9841d48a10
55
tests/engine/test_detokenize.py
Normal file
55
tests/engine/test_detokenize.py
Normal file
@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||
|
||||
TRUTH = [
|
||||
"Hello here, this is a simple test",
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
|
||||
"我很感谢你的热情"
|
||||
]
|
||||
TOKENIZERS = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"codellama/CodeLlama-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
def _run_incremental_decode(tokenizer, all_input_ids):
|
||||
decoded_text = ""
|
||||
offset = 0
|
||||
token_offset = 0
|
||||
prev_tokens = None
|
||||
for i in range(len(all_input_ids)):
|
||||
new_tokens, text, offset, token_offset = detokenize_incrementally(
|
||||
tokenizer,
|
||||
all_input_ids[:i + 1],
|
||||
prev_tokens,
|
||||
offset,
|
||||
token_offset,
|
||||
skip_special_tokens=False)
|
||||
decoded_text += text
|
||||
if prev_tokens is None:
|
||||
prev_tokens = new_tokens
|
||||
else:
|
||||
prev_tokens += new_tokens
|
||||
return decoded_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("truth", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
|
||||
def test_decode_streaming(tokenizer_id, truth):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
|
||||
|
||||
decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
|
||||
|
||||
assert decoded_text == truth
|
||||
@ -623,15 +623,22 @@ class LLMEngine:
|
||||
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
(new_tokens, new_output_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
all_input_ids=seq.get_token_ids(),
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if seq.tokens is None:
|
||||
seq.tokens = new_tokens
|
||||
else:
|
||||
seq.tokens.extend(new_tokens)
|
||||
seq.prefix_offset = prefix_offset
|
||||
seq.read_offset = read_offset
|
||||
seq.output_text += new_output_text
|
||||
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
|
||||
@ -114,7 +114,6 @@ class Sequence:
|
||||
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.output_tokens: List[str] = []
|
||||
self.output_text = ""
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
@ -122,6 +121,12 @@ class Sequence:
|
||||
self._append_tokens_to_blocks(prompt_token_ids)
|
||||
self.status = SequenceStatus.WAITING
|
||||
|
||||
# Used for incremental detokenization
|
||||
self.prefix_offset = 0
|
||||
self.read_offset = 0
|
||||
# Input + output tokens
|
||||
self.tokens: Optional[List[str]] = None
|
||||
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
block_number=len(self.logical_token_blocks),
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
@ -67,33 +67,11 @@ def get_tokenizer(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def detokenize_incrementally(
|
||||
def _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
prev_output_tokens: List[str],
|
||||
new_token_id: int,
|
||||
output_tokens: List[str],
|
||||
skip_special_tokens: bool,
|
||||
) -> Tuple[str, str]:
|
||||
"""Detokenizes the new token in conjunction with the previous output tokens.
|
||||
|
||||
NOTE: This function does not update prev_output_tokens.
|
||||
|
||||
Returns:
|
||||
new_token: The new token as a string.
|
||||
output_text: The new output text as a string.
|
||||
"""
|
||||
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
|
||||
return None, prev_output_tokens
|
||||
new_token = tokenizer.convert_ids_to_tokens(
|
||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_output_tokens + [new_token]
|
||||
|
||||
# Convert the tokens to a string.
|
||||
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
|
||||
# then we can directly use `convert_tokens_to_string`.
|
||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
output_text = tokenizer.convert_tokens_to_string(output_tokens)
|
||||
return new_token, output_text
|
||||
|
||||
) -> str:
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
|
||||
# NOTE(woosuk): The following code is slow because it runs a for loop over
|
||||
@ -115,5 +93,61 @@ def detokenize_incrementally(
|
||||
if current_sub_text:
|
||||
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
output_text = " ".join(sub_texts)
|
||||
return new_token, output_text
|
||||
return " ".join(sub_texts)
|
||||
|
||||
|
||||
# Based on
|
||||
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||||
# under Apache 2.0 license
|
||||
def detokenize_incrementally(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
all_input_ids: List[int],
|
||||
prev_tokens: Optional[List[str]],
|
||||
prefix_offset: int = 0,
|
||||
read_offset: int = 0,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> Tuple[List[str], str, int, int]:
|
||||
new_token_id = all_input_ids[-1]
|
||||
# This is the first iteration for this sequence
|
||||
if prev_tokens is None:
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = new_tokens
|
||||
# 5 is an arbitrary value that should work for all
|
||||
# tokenizers (bigger = more conservative).
|
||||
# Subtract 1 extra to account for the generated token.
|
||||
prefix_offset = max(len(output_tokens) - 6, 0)
|
||||
read_offset = max(len(output_tokens) - 1, 0)
|
||||
else:
|
||||
new_token = tokenizer.convert_ids_to_tokens(
|
||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
||||
new_tokens = [new_token]
|
||||
output_tokens = prev_tokens + new_tokens
|
||||
|
||||
# The prefix text is necessary only to defeat cleanup algorithms in
|
||||
# the decode which decide to add a space or not depending on the
|
||||
# surrounding ids.
|
||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
prefix_text = tokenizer.convert_tokens_to_string(
|
||||
output_tokens[prefix_offset:read_offset])
|
||||
new_text = tokenizer.convert_tokens_to_string(
|
||||
output_tokens[prefix_offset:])
|
||||
else:
|
||||
prefix_text = _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer,
|
||||
output_tokens[prefix_offset:read_offset],
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
new_text = _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer,
|
||||
output_tokens[prefix_offset:],
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"):
|
||||
# utf-8 char at the end means it's a potential unfinished byte sequence
|
||||
# from byte fallback tokenization.
|
||||
# If it's in the middle, it's probably a real invalid id generated
|
||||
# by the model
|
||||
new_text = new_text[len(prefix_text):]
|
||||
return new_tokens, new_text, read_offset, len(output_tokens)
|
||||
else:
|
||||
return new_tokens, "", prefix_offset, read_offset
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user