From dd54a4b026455f728f9d5945eca369b2be7b12f9 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 14 Sep 2023 16:37:03 -0700
Subject: [PATCH] Fix detokenization leaving special tokens (#1044)

Signed-off-by: Antoni Baum
---
 tests/engine/test_detokenize.py      | 15 +++++++++++----
 vllm/transformers_utils/tokenizer.py |  6 +++---
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/tests/engine/test_detokenize.py b/tests/engine/test_detokenize.py
index 405904707632..fc5936c7434e 100644
--- a/tests/engine/test_detokenize.py
+++ b/tests/engine/test_detokenize.py
@@ -23,7 +23,8 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids):
+def _run_incremental_decode(tokenizer, all_input_ids,
+                            skip_special_tokens: bool):
     decoded_text = ""
     offset = 0
     token_offset = 0
@@ -35,7 +36,7 @@ def _run_incremental_decode(tokenizer, all_input_ids):
             prev_tokens,
             offset,
             token_offset,
-            skip_special_tokens=False)
+            skip_special_tokens=skip_special_tokens)
         decoded_text += text
         if prev_tokens is None:
             prev_tokens = new_tokens
@@ -46,10 +47,16 @@
 
 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-def test_decode_streaming(tokenizer_id, truth):
+@pytest.mark.parametrize("skip_special_tokens", (True, False))
+def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
     all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if skip_special_tokens:
+        all_input_ids = ([tokenizer.bos_token_id]
+                         if tokenizer.bos_token_id is not None else
+                         []) + all_input_ids + [tokenizer.eos_token_id]
 
-    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+    decoded_text = _run_incremental_decode(
+        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
 
     assert decoded_text == truth
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index d1275a1dd2b4..10f57b4082fa 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -119,9 +119,9 @@ def detokenize_incrementally(
         prefix_offset = max(len(output_tokens) - 6, 0)
         read_offset = max(len(output_tokens) - 1, 0)
     else:
-        new_token = tokenizer.convert_ids_to_tokens(
-            new_token_id, skip_special_tokens=skip_special_tokens)
-        new_tokens = [new_token]
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
         output_tokens = prev_tokens + new_tokens
 
     # The prefix text is necessary only to defeat cleanup algorithms in
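
Background on the fix (an illustrative note, not part of the patch): Hugging Face tokenizers' convert_ids_to_tokens only honors skip_special_tokens when it receives a list of ids; a bare int takes a scalar code path that returns the token string unconditionally, which is why streamed output kept special tokens such as BOS/EOS. The sketch below assumes the transformers package is installed and uses "gpt2" purely as an example checkpoint; exact outputs can vary between transformers versions.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model only
    eos_id = tokenizer.eos_token_id  # id of a special token

    # Scalar id: the skip_special_tokens flag is ignored on this code path,
    # so the special token string comes back anyway (e.g. '<|endoftext|>').
    print(tokenizer.convert_ids_to_tokens(eos_id, skip_special_tokens=True))

    # One-element list (what the patched code passes): the special token is
    # filtered out, so the incremental detokenizer never emits it ([]).
    print(tokenizer.convert_ids_to_tokens([eos_id], skip_special_tokens=True))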