diff --git a/tests/test_utils.py b/tests/test_utils.py
index a165d2d7213a..f90715fd7513 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -14,9 +14,12 @@ from unittest.mock import patch
 import pytest
 import torch
 import zmq
+from transformers import AutoTokenizer
 from vllm_test_utils.monitor import monitor
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.transformers_utils.detokenizer_utils import (
+    convert_ids_list_to_tokens)
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
@@ -918,3 +921,14 @@ def test_split_host_port():
 def test_join_host_port():
     assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
     assert join_host_port("::1", 5555) == "[::1]:5555"
+
+
+def test_convert_ids_list_to_tokens():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+    token_ids = tokenizer.encode("Hello, world!")
+    # token_ids = [9707, 11, 1879, 0]
+    assert tokenizer.convert_ids_to_tokens(token_ids) == [
+        'Hello', ',', 'Ġworld', '!'
+    ]
+    tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
+    assert tokens == ['Hello', ',', ' world', '!']
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 1c8c5f25e29b..949ab764e2e9 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -35,7 +35,7 @@ def _ref_convert_id_to_token(
     Returns:
         String representation of input token id
     """
-    return tokenizer.convert_ids_to_tokens(token_id) or ""
+    return tokenizer.decode([token_id]) or ""
 
 
 @pytest.mark.parametrize(
diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py
index 342632989d57..6812cda7110f 100644
--- a/vllm/transformers_utils/detokenizer_utils.py
+++ b/vllm/transformers_utils/detokenizer_utils.py
@@ -78,6 +78,7 @@ def convert_prompt_ids_to_tokens(
 def convert_ids_list_to_tokens(
     tokenizer: AnyTokenizer,
     token_ids: list[int],
+    skip_special_tokens: bool = False,
 ) -> list[str]:
     """Detokenize the input ids individually.
@@ -89,8 +90,15 @@
         Python list of token string representations
     """
-    token_str_lst = tokenizer.convert_ids_to_tokens(token_ids)
-    _replace_none_with_empty(token_str_lst)  # type: ignore
+    token_str_lst = []
+    for token_id in token_ids:
+        token_str = tokenizer.decode(
+            [token_id],
+            skip_special_tokens=skip_special_tokens,
+        )
+        if token_str is None:
+            token_str = ""
+        token_str_lst.append(token_str)
     return token_str_lst
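
A minimal sketch (illustration only, not part of the diff) of the behavior this change targets, assuming the same Qwen2.5 tokenizer the new test uses: convert_ids_to_tokens returns raw vocabulary strings, where byte-level BPE encodes a leading space as the "Ġ" marker, while decoding each id individually runs the byte-level decoder and recovers the actual space character.

    # Sketch: per-id decode vs. raw vocabulary lookup.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    token_ids = tokenizer.encode("Hello, world!")  # [9707, 11, 1879, 0]

    # Raw vocabulary entries: the space before "world" shows up as "Ġ".
    print(tokenizer.convert_ids_to_tokens(token_ids))
    # ['Hello', ',', 'Ġworld', '!']

    # Decoding each id on its own yields the real text, one token at a time.
    print([tokenizer.decode([tid]) for tid in token_ids])
    # ['Hello', ',', ' world', '!']

This is why the diff swaps convert_ids_to_tokens for a per-id decode loop in convert_ids_list_to_tokens: each returned string is then plain text rather than a tokenizer-internal representation. The trade-off, one decode call per token id instead of a single batched vocabulary lookup, is accepted here in exchange for correct output strings.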