# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import _get_protocol_attrs  # type: ignore

import pytest
from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer


def _get_missing_attrs(obj: object, target: type):
    """Return the protocol attributes of ``target`` that ``obj`` lacks.

    Presence is checked with ``hasattr`` only; attribute types and call
    signatures are not inspected.
    """
    absent = []
    for attr_name in _get_protocol_attrs(target):
        if hasattr(obj, attr_name):
            continue
        absent.append(attr_name)
    return absent


def _assert_tokenizer_like(tokenizer: object):
    """Fail if ``tokenizer`` lacks any attribute declared by ``TokenizerLike``."""
    absent = _get_missing_attrs(tokenizer, TokenizerLike)
    assert not absent, f"Missing attrs: {absent}"


def test_tokenizer_like_protocol():
    """Slow HF, fast HF and Mistral tokenizers all expose ``TokenizerLike``."""
    # Slow Hugging Face tokenizer.
    slow_tokenizer = get_tokenizer("gpt2", use_fast=False)
    assert isinstance(slow_tokenizer, PreTrainedTokenizer)
    _assert_tokenizer_like(slow_tokenizer)

    # Fast Hugging Face tokenizer.
    fast_tokenizer = get_tokenizer("gpt2", use_fast=True)
    assert isinstance(fast_tokenizer, PreTrainedTokenizerFast)
    _assert_tokenizer_like(fast_tokenizer)

    # Mistral tokenizer, selected through ``tokenizer_mode``.
    mistral_tokenizer = get_tokenizer(
        "mistralai/Mistral-7B-Instruct-v0.3",
        tokenizer_mode="mistral",
    )
    assert isinstance(mistral_tokenizer, MistralTokenizer)
    _assert_tokenizer_like(mistral_tokenizer)


@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
    """A valid revision loads a tokenizer; an unknown one raises ``OSError``."""
    # Assume that "main" branch always exists
    loaded = get_tokenizer(tokenizer_name, revision="main")
    assert isinstance(loaded, PreTrainedTokenizerBase)

    # Assume that "never" branch always does not exist
    with pytest.raises(OSError, match="not a valid git identifier"):
        get_tokenizer(tokenizer_name, revision="never")


@pytest.mark.parametrize("tokenizer_name", ["BAAI/bge-base-en"])
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Repeated ``[UNK]`` text encodes to ``n_tokens`` ids plus two extra."""
    tokenizer = get_tokenizer(tokenizer_name, revision="main")

    prompt = "[UNK]" * n_tokens
    encoded_ids = tokenizer.encode(prompt)

    # NOTE(review): the two extra ids are presumably special tokens the
    # tokenizer adds around the prompt (e.g. [CLS]/[SEP]) -- confirm.
    expected_length = n_tokens + 2
    assert len(encoded_ids) == expected_length