diff --git a/tests/tokenization/test_do_lower_case.py b/tests/tokenization/test_do_lower_case.py
new file mode 100644
index 000000000000..7aa655e1c3b4
--- /dev/null
+++ b/tests/tokenization/test_do_lower_case.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+TOKENIZER_NAMES = ["BAAI/bge-base-en"]
+
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
+@pytest.mark.parametrize("n_tokens", [510])
+def test_special_tokens(tokenizer_name: str, n_tokens: int):
+    """Check that each literal '[UNK]' in a prompt encodes to one token id."""
+    tokenizer = get_tokenizer(tokenizer_name, revision="main")
+
+    # The two extra ids are presumably the markers encode() wraps around the
+    # sequence (e.g. [CLS]/[SEP]) -- confirm against the tokenizer config.
+    prompts = '[UNK]' * n_tokens
+    prompt_token_ids = tokenizer.encode(prompts)
+    assert len(prompt_token_ids) == n_tokens + 2
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 01d1769f0e5e..25dd71d877fb 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -16,6 +16,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 
 from vllm import envs
 from vllm.logger import init_logger
+from vllm.transformers_utils.config import (
+    get_sentence_transformer_tokenizer_config)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
@@ -256,6 +258,22 @@ def get_tokenizer(
             else:
                 raise e
 
+        # The special_tokens in tokenizer should also be
+        # controlled by do_lower_case in encoder_config
+        encoder_config = get_sentence_transformer_tokenizer_config(
+            tokenizer_name, revision)
+        if isinstance(encoder_config, dict) and encoder_config.get(
+                "do_lower_case", False):
+            # special_tokens_map values are plain strings, except that
+            # entries such as "additional_special_tokens" hold a list of
+            # strings; lower-case both shapes instead of assuming str.
+            special_tokens_map = {
+                k: (v.lower() if isinstance(v, str) else
+                    [token.lower() for token in v])
+                for k, v in tokenizer.special_tokens_map.items()
+            }
+            tokenizer.add_special_tokens(special_tokens_map)
+
         # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
         if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                         "ChatGLM4Tokenizer"):