[Misc] Make cached tokenizer pickle-compatible (#17048)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-04-27 13:05:00 +08:00, committed by GitHub
parent 8e4b351a0c
commit 93a126fbc7
5 changed files with 81 additions and 57 deletions


@@ -63,14 +63,16 @@ class Request:
     output_len: int
 
 
-def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase,
+                  length: int) -> list[int]:
     vocab = tokenizer.get_vocab()
+    all_special_ids = set(tokenizer.all_special_ids)
+
     # Remove the special tokens.
-    vocab = {
-        k: v
-        for k, v in vocab.items() if k not in tokenizer.all_special_ids
-    }
-    return random.choices(list(vocab.values()), k=length)
+    return random.choices(
+        [v for k, v in vocab.items() if k not in all_special_ids],
+        k=length,
+    )
 
 
 def sample_requests_from_dataset(
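
The helper above samples raw token ids, and its return annotation now matches what random.choices actually produces. For reference, a minimal self-contained sketch of the same idea; the helper name sample_token_ids, the "gpt2" model id, and the sample length are illustrative assumptions rather than part of the diff, and this variant filters special tokens out by id before sampling:

import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def sample_token_ids(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
    # Keep only non-special token ids, then sample with replacement so any
    # prompt length can be produced.
    special_ids = set(tokenizer.all_special_ids)
    candidates = [v for v in tokenizer.get_vocab().values() if v not in special_ids]
    return random.choices(candidates, k=length)


tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt_ids = sample_token_ids(tokenizer, length=32)
prompt_text = tokenizer.decode(prompt_ids)  # ids can still be turned back into text
print(len(prompt_ids), prompt_text[:60])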


@@ -1,24 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pickle
 from copy import deepcopy
 
+import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
 
 
-def test_cached_tokenizer():
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
+def test_cached_tokenizer(model_id: str):
+    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                        trust_remote_code=True)
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
     reference_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<SEP>"]})
-    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
 
-    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
-        "prompt")
-    assert set(reference_tokenizer.all_special_ids) == set(
-        cached_tokenizer.all_special_ids)
-    assert set(reference_tokenizer.all_special_tokens) == set(
-        cached_tokenizer.all_special_tokens)
-    assert set(reference_tokenizer.all_special_tokens_extended) == set(
-        cached_tokenizer.all_special_tokens_extended)
+    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
+    _check_consistency(cached_tokenizer, reference_tokenizer)
+
+    pickled_tokenizer = pickle.dumps(cached_tokenizer)
+    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
+    _check_consistency(unpickled_tokenizer, reference_tokenizer)
+
+
+def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
+    assert isinstance(target, type(expected))
+
+    # Cached attributes
+    assert target.all_special_ids == expected.all_special_ids
+    assert target.all_special_tokens == expected.all_special_tokens
+    assert (target.all_special_tokens_extended ==
+            expected.all_special_tokens_extended)
+    assert target.get_vocab() == expected.get_vocab()
+    assert len(target) == len(expected)
+
+    # Other attributes
+    assert getattr(target, "padding_side",
+                   None) == getattr(expected, "padding_side", None)
+
+    assert target.encode("prompt") == expected.encode("prompt")
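
Beyond the original consistency checks, the test now round-trips the cached tokenizer through pickle. A minimal sketch of the behaviour being guarded, assuming the vLLM helper is importable and a "gpt2" checkpoint is available; picklability is what lets the cached tokenizer cross process boundaries, e.g. when it is handed to worker processes:

import pickle

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_cached_tokenizer

cached = get_cached_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# Before this change, pickling failed: the object's class was the dynamically
# created CachedTokenizer subclass, which pickle cannot look up by module and
# name. With __reduce__, pickle instead records get_cached_tokenizer plus the
# original tokenizer and rebuilds the cached wrapper on load.
restored = pickle.loads(pickle.dumps(cached))

assert restored.all_special_ids == cached.all_special_ids
assert restored.encode("prompt") == cached.encode("prompt")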


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)
 
-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
 
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids
 
         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
             return tokenizer_all_special_tokens
 
         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
             return tokenizer_all_special_tokens_extended
 
         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id
 
-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
 
-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len
 
+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
 
-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer
 
 
 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
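
Two pieces make the cached tokenizer pickle-compatible: the expensive properties are overridden on a dynamically created subclass that captures the precomputed values, and __reduce__ tells pickle to rebuild the object by calling the factory again, because a class defined inside a function cannot be located by name during unpickling. A self-contained sketch of the same pattern on a toy class (all names below are illustrative, not vLLM's):

import copy
import pickle


class SlowThing:
    """Stand-in for a tokenizer whose property is expensive to recompute."""

    def __init__(self, items: list[int]) -> None:
        self.items = items

    @property
    def expensive_summary(self) -> int:
        # Imagine this walks a large vocabulary on every access.
        return sum(self.items)


def get_cached_thing(thing: SlowThing) -> SlowThing:
    cached_copy = copy.copy(thing)
    cached_summary = thing.expensive_summary  # computed once, captured below

    class CachedThing(thing.__class__):  # dynamically created subclass

        @property
        def expensive_summary(self) -> int:
            return cached_summary

        def __reduce__(self):
            # Pickle cannot import CachedThing (it only exists inside this
            # function), so rebuild from the original object via the factory.
            return get_cached_thing, (thing, )

    CachedThing.__name__ = f"Cached{thing.__class__.__name__}"
    cached_copy.__class__ = CachedThing
    return cached_copy


cached = get_cached_thing(SlowThing([1, 2, 3]))
restored = pickle.loads(pickle.dumps(cached))  # works thanks to __reduce__
assert restored.expensive_summary == cached.expensive_summary == 6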


@@ -2,7 +2,7 @@
 
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
 
     @property
     @abstractmethod
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         raise NotImplementedError()
 
     @property
@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
     @abstractmethod
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def encode(self,
                text: str,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
+               add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
                skip_special_tokens: bool = True) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         raise NotImplementedError()
 
 
 class TokenizerRegistry:
     # Tokenizer name -> (tokenizer module, tokenizer class)
-    REGISTRY: Dict[str, Tuple[str, str]] = {}
+    REGISTRY: dict[str, tuple[str, str]] = {}
 
     @staticmethod
     def register(name: str, module: str, class_name: str) -> None:

@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
     # the following attributes are set to fit vLLM's design and are used
     # by the guided structured output backends.
    @property
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # tekken defines its own extended special tokens list
@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
         ]
 
     @property
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         return self.all_special_tokens_extended
 
     @property
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         return [
             self.all_special_tokens.index(t) for t in self.all_special_tokens
         ]
@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
         input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         # NB: the dictionary form of the vocabulary collapses token ids that map
         # to the same string but have different bytes
         return self._vocab_dict
 
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         # Mistral tokenizers have no added vocabulary
         return {}