Fix detokenization leaving special tokens (#1044)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Antoni Baum 2023-09-14 16:37:03 -07:00 committed by GitHub
parent eda1a7cad3
commit dd54a4b026
2 changed files with 14 additions and 7 deletions


@@ -23,7 +23,8 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids):
+def _run_incremental_decode(tokenizer, all_input_ids,
+                            skip_special_tokens: bool):
     decoded_text = ""
     offset = 0
     token_offset = 0
@@ -35,7 +36,7 @@ def _run_incremental_decode(tokenizer, all_input_ids):
             prev_tokens,
             offset,
             token_offset,
-            skip_special_tokens=False)
+            skip_special_tokens=skip_special_tokens)
         decoded_text += text
         if prev_tokens is None:
             prev_tokens = new_tokens
@@ -46,10 +47,16 @@ def _run_incremental_decode(tokenizer, all_input_ids):
 
 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-def test_decode_streaming(tokenizer_id, truth):
+@pytest.mark.parametrize("skip_special_tokens", (True, False))
+def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
     all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if skip_special_tokens:
+        all_input_ids = ([tokenizer.bos_token_id]
+                         if tokenizer.bos_token_id is not None else
+                         []) + all_input_ids + [tokenizer.eos_token_id]
 
-    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+    decoded_text = _run_incremental_decode(
+        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
 
     assert decoded_text == truth
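
The property the new parametrization asserts, stated outside the incremental path: wrapping the prompt ids in BOS/EOS must not change the decoded text once special tokens are skipped. A minimal standalone sketch using only Hugging Face transformers (illustration only, not part of this diff; the tokenizer name and sample string are arbitrary choices):

from transformers import AutoTokenizer

# Illustration only: the round-trip property the updated test checks,
# without vLLM's incremental detokenizer in the loop.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
truth = "Hello here, this is a simple test"

plain_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
wrapped_ids = ([tokenizer.bos_token_id] if tokenizer.bos_token_id is not None
               else []) + plain_ids + [tokenizer.eos_token_id]

# Skipping special tokens hides the BOS/EOS wrapper entirely...
assert (tokenizer.decode(wrapped_ids, skip_special_tokens=True) ==
        tokenizer.decode(plain_ids, skip_special_tokens=True))
# ...while keeping them leaves the wrapper visible in the output.
assert (tokenizer.decode(wrapped_ids, skip_special_tokens=False) !=
        tokenizer.decode(plain_ids, skip_special_tokens=False))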


@@ -119,9 +119,9 @@ def detokenize_incrementally(
         prefix_offset = max(len(output_tokens) - 6, 0)
         read_offset = max(len(output_tokens) - 1, 0)
     else:
-        new_token = tokenizer.convert_ids_to_tokens(
-            new_token_id, skip_special_tokens=skip_special_tokens)
-        new_tokens = [new_token]
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
         output_tokens = prev_tokens + new_tokens
 
     # The prefix text is necessary only to defeat cleanup algorithms in
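
The comment added in the else branch points at the root cause: Hugging Face's convert_ids_to_tokens honors skip_special_tokens only when it is given a list of ids, while a bare int takes the single-id path and the flag is silently ignored, so special tokens leaked into the streamed text. A minimal standalone repro of that behavior (illustration only, not part of this diff; the tokenizer name is an arbitrary example):

from transformers import AutoTokenizer

# Illustration only: skip_special_tokens is ignored for a bare int id,
# but respected when the id is wrapped in a list (what the fix does).
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
eos = tokenizer.eos_token_id

# Bare int: prints '</s>' even though skip_special_tokens=True.
print(tokenizer.convert_ids_to_tokens(eos, skip_special_tokens=True))

# One-element list: prints [] because the special token is filtered out,
# mirroring the new `[new_token_id]` call above.
print(tokenizer.convert_ids_to_tokens([eos], skip_special_tokens=True))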