mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
Fix detokenization leaving special tokens (#1044)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
eda1a7cad3
commit
dd54a4b026
@ -23,7 +23,8 @@ TOKENIZERS = [
|
||||
]
|
||||
|
||||
|
||||
def _run_incremental_decode(tokenizer, all_input_ids):
|
||||
def _run_incremental_decode(tokenizer, all_input_ids,
|
||||
skip_special_tokens: bool):
|
||||
decoded_text = ""
|
||||
offset = 0
|
||||
token_offset = 0
|
||||
@ -35,7 +36,7 @@ def _run_incremental_decode(tokenizer, all_input_ids):
|
||||
prev_tokens,
|
||||
offset,
|
||||
token_offset,
|
||||
skip_special_tokens=False)
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
decoded_text += text
|
||||
if prev_tokens is None:
|
||||
prev_tokens = new_tokens
|
||||
@ -46,10 +47,16 @@ def _run_incremental_decode(tokenizer, all_input_ids):
|
||||
|
||||
@pytest.mark.parametrize("truth", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
|
||||
def test_decode_streaming(tokenizer_id, truth):
|
||||
@pytest.mark.parametrize("skip_special_tokens", (True, False))
|
||||
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
|
||||
if skip_special_tokens:
|
||||
all_input_ids = ([tokenizer.bos_token_id]
|
||||
if tokenizer.bos_token_id is not None else
|
||||
[]) + all_input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
|
||||
decoded_text = _run_incremental_decode(
|
||||
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
assert decoded_text == truth
|
||||
|
||||
@ -119,9 +119,9 @@ def detokenize_incrementally(
|
||||
prefix_offset = max(len(output_tokens) - 6, 0)
|
||||
read_offset = max(len(output_tokens) - 1, 0)
|
||||
else:
|
||||
new_token = tokenizer.convert_ids_to_tokens(
|
||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
||||
new_tokens = [new_token]
|
||||
# Put new_token_id in a list so skip_special_tokens is respected
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
[new_token_id], skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_tokens + new_tokens
|
||||
|
||||
# The prefix text is necessary only to defeat cleanup algorithms in
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user