# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pickle
from copy import deepcopy

import pytest
from transformers import AutoTokenizer

from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.hf import get_cached_tokenizer


@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
def test_cached_tokenizer(model_id: str):
    """Check that a cached tokenizer matches its source, also after pickling.

    Extra special tokens are registered on the reference tokenizer first so
    that the cached special-token attributes have non-default values to
    reflect.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True
    )
    # Use non-empty token strings: an empty string is a degenerate special
    # token and does not meaningfully exercise the added-special-token path.
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]})

    # Wrap a deep copy so the reference tokenizer is left untouched (in case
    # get_cached_tokenizer modifies its argument) and remains the ground truth.
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
    _check_consistency(cached_tokenizer, reference_tokenizer)

    # The cached tokenizer must also survive a pickle round trip.
    pickled_tokenizer = pickle.dumps(cached_tokenizer)
    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
    _check_consistency(unpickled_tokenizer, reference_tokenizer)


def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
    """Assert that ``target`` behaves like ``expected``.

    Args:
        target: The tokenizer under test (cached or unpickled).
        expected: The reference tokenizer it must agree with.
    """
    # The target must still be usable wherever the original type is expected.
    assert isinstance(target, type(expected))

    # Cached attributes
    assert target.all_special_ids == expected.all_special_ids
    assert target.all_special_tokens == expected.all_special_tokens
    assert target.get_vocab() == expected.get_vocab()
    assert len(target) == len(expected)

    # Other attributes
    assert getattr(target, "padding_side", None) == getattr(
        expected, "padding_side", None
    )

    assert target.encode("prompt") == expected.encode("prompt")


def test_cached_tokenizer_preserves_class_name():
    """Test that the cached tokenizer preserves the original class name.

    This matters for compatibility with HuggingFace transformers processors,
    whose type checking validates the tokenizer's class name.

    See: https://github.com/vllm-project/vllm/issues/31080
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Record the original names *before* wrapping. If get_cached_tokenizer
    # rebinds ``__class__`` on its argument, reading the names afterwards
    # would compare the cached class with itself.
    original_class_name = tokenizer.__class__.__name__
    original_class_qualname = tokenizer.__class__.__qualname__

    cached_tokenizer = get_cached_tokenizer(tokenizer)

    # The cached tokenizer's class should have the same name as the original.
    assert cached_tokenizer.__class__.__name__ == original_class_name
    assert cached_tokenizer.__class__.__qualname__ == original_class_qualname