Fix detokenization leaving special tokens (#1044)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
parent eda1a7cad3
commit dd54a4b026
@@ -23,7 +23,8 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids):
+def _run_incremental_decode(tokenizer, all_input_ids,
+                            skip_special_tokens: bool):
     decoded_text = ""
     offset = 0
     token_offset = 0
@@ -35,7 +36,7 @@ def _run_incremental_decode(tokenizer, all_input_ids):
             prev_tokens,
             offset,
             token_offset,
-            skip_special_tokens=False)
+            skip_special_tokens=skip_special_tokens)
        decoded_text += text
         if prev_tokens is None:
             prev_tokens = new_tokens
@@ -46,10 +47,16 @@ def _run_incremental_decode(tokenizer, all_input_ids):
 
 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-def test_decode_streaming(tokenizer_id, truth):
+@pytest.mark.parametrize("skip_special_tokens", (True, False))
+def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
     all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if skip_special_tokens:
+        all_input_ids = ([tokenizer.bos_token_id]
+                         if tokenizer.bos_token_id is not None else
+                         []) + all_input_ids + [tokenizer.eos_token_id]
 
-    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+    decoded_text = _run_incremental_decode(
+        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
 
     assert decoded_text == truth
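The added skip_special_tokens parametrization exercises a simple round trip: encode the truth string without special tokens, surround the ids with BOS/EOS, and require that decoding with skip_special_tokens=True reproduces the original text. Below is a minimal non-incremental sketch of that property, assuming the Hugging Face transformers API and an arbitrary gpt2 tokenizer (not necessarily one of the entries in TOKENIZERS):

# Sketch only; not part of the diff. The test exercises the incremental path
# through _run_incremental_decode rather than a single tokenizer.decode call.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
truth = "Hello, world!"

# Encode without special tokens, then surround with BOS/EOS as the test does.
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
all_input_ids = ([tokenizer.bos_token_id]
                 if tokenizer.bos_token_id is not None else
                 []) + all_input_ids + [tokenizer.eos_token_id]

# With skip_special_tokens=True the added ids must not leak into the text.
assert tokenizer.decode(all_input_ids, skip_special_tokens=True) == truth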
@@ -119,9 +119,9 @@ def detokenize_incrementally(
         prefix_offset = max(len(output_tokens) - 6, 0)
         read_offset = max(len(output_tokens) - 1, 0)
     else:
-        new_token = tokenizer.convert_ids_to_tokens(
-            new_token_id, skip_special_tokens=skip_special_tokens)
-        new_tokens = [new_token]
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
         output_tokens = prev_tokens + new_tokens
 
     # The prefix text is necessary only to defeat cleanup algorithms in
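The comment added in the hunk above is the core of the fix: Hugging Face's convert_ids_to_tokens only applies skip_special_tokens when it receives a list of ids, while the single-int form returns the token unconditionally, which is how special tokens ended up in the detokenized output. A short sketch of that behavior, with gpt2 chosen arbitrarily for illustration:

# Sketch only; the tokenizer choice is an assumption, not taken from the diff.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_id = tokenizer.eos_token_id

# Int form: the special token comes back even though we asked to skip it.
print(tokenizer.convert_ids_to_tokens(eos_id, skip_special_tokens=True))
# -> '<|endoftext|>'

# List form (what detokenize_incrementally now uses): the special id is dropped.
print(tokenizer.convert_ids_to_tokens([eos_id], skip_special_tokens=True))
# -> []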