mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:46:25 +08:00
[Bugfix] The special_tokens in tokenizer should also be controlled by do_lower_case in encoder_config. (#20750)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
ca4eb82bcb
commit
5895afd780
18
tests/tokenization/test_do_lower_case.py
Normal file
18
tests/tokenization/test_do_lower_case.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
# Tokenizers to exercise; presumably a BERT-style model whose
# sentence-transformers config sets do_lower_case=True — confirm
# against the encoder_config handling in get_tokenizer.
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Special tokens must survive do_lower_case normalization.

    Encoding ``n_tokens`` back-to-back copies of the ``[UNK]`` special
    token should produce exactly one id per copy, plus the two ids the
    tokenizer wraps around every sequence — if lowercasing broke the
    special-token match, each ``[unk]`` would shatter into several
    sub-word ids and the count would differ.
    """
    tokenizer = get_tokenizer(tokenizer_name, revision="main")

    text = "".join(["[UNK]"] * n_tokens)
    encoded_ids = tokenizer.encode(text)
    # n_tokens special-token ids + 2 sequence-wrapper ids.
    assert len(encoded_ids) == n_tokens + 2
|
||||||
@ -16,6 +16,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.config import (
|
||||||
|
get_sentence_transformer_tokenizer_config)
|
||||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.utils import make_async
|
from vllm.utils import make_async
|
||||||
@ -256,6 +258,18 @@ def get_tokenizer(
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
# The special_tokens in tokenizer should also be
|
||||||
|
# controlled by do_lower_case in encoder_config
|
||||||
|
encoder_config = get_sentence_transformer_tokenizer_config(
|
||||||
|
tokenizer_name, revision)
|
||||||
|
if isinstance(encoder_config, dict) and encoder_config.get(
|
||||||
|
"do_lower_case", False):
|
||||||
|
special_tokens_map = {
|
||||||
|
k: v.lower()
|
||||||
|
for k, v in tokenizer.special_tokens_map.items()
|
||||||
|
}
|
||||||
|
tokenizer.add_special_tokens(special_tokens_map)
|
||||||
|
|
||||||
# NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
|
# NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
|
||||||
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
|
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
|
||||||
"ChatGLM4Tokenizer"):
|
"ChatGLM4Tokenizer"):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user