[Misc] Make cached tokenizer pickle-compatible (#17048)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-04-27 13:05:00 +08:00 committed by GitHub
parent 8e4b351a0c
commit 93a126fbc7
5 changed files with 81 additions and 57 deletions

View File

@@ -63,14 +63,16 @@ class Request:
     output_len: int
 
 
-def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase,
+                  length: int) -> list[int]:
     vocab = tokenizer.get_vocab()
+    all_special_ids = set(tokenizer.all_special_ids)
+
     # Remove the special tokens.
-    vocab = {
-        k: v
-        for k, v in vocab.items() if k not in tokenizer.all_special_ids
-    }
-    return random.choices(list(vocab.values()), k=length)
+    return random.choices(
+        [v for k, v in vocab.items() if k not in all_special_ids],
+        k=length,
+    )
 
 
 def sample_requests_from_dataset(
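For context, a rough self-contained sketch (not part of this commit) of how a caller can use a helper with this new list[int] signature to build a synthetic prompt of a given token length. The helper below is a simplified analogue written for this note, and "gpt2" is only an illustrative model choice.

import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def sample_token_ids(tokenizer: PreTrainedTokenizerBase,
                     length: int) -> list[int]:
    # Simplified analogue of the benchmark helper: draw random
    # non-special token ids from the vocabulary.
    special_ids = set(tokenizer.all_special_ids)
    candidates = [
        v for v in tokenizer.get_vocab().values() if v not in special_ids
    ]
    return random.choices(candidates, k=length)


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
prompt_ids = sample_token_ids(tokenizer, length=32)
prompt = tokenizer.decode(prompt_ids)
print(len(prompt_ids), prompt[:60])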

View File

@@ -1,24 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pickle
 from copy import deepcopy
 
+import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
 
 
-def test_cached_tokenizer():
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
+def test_cached_tokenizer(model_id: str):
+    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                        trust_remote_code=True)
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
     reference_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<SEP>"]})
-    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
 
-    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
-        "prompt")
-    assert set(reference_tokenizer.all_special_ids) == set(
-        cached_tokenizer.all_special_ids)
-    assert set(reference_tokenizer.all_special_tokens) == set(
-        cached_tokenizer.all_special_tokens)
-    assert set(reference_tokenizer.all_special_tokens_extended) == set(
-        cached_tokenizer.all_special_tokens_extended)
+    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
+    _check_consistency(cached_tokenizer, reference_tokenizer)
+
+    pickled_tokenizer = pickle.dumps(cached_tokenizer)
+    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
+    _check_consistency(unpickled_tokenizer, reference_tokenizer)
+
+
+def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
+    assert isinstance(target, type(expected))
+
+    # Cached attributes
+    assert target.all_special_ids == expected.all_special_ids
+    assert target.all_special_tokens == expected.all_special_tokens
+    assert (target.all_special_tokens_extended ==
+            expected.all_special_tokens_extended)
+    assert target.get_vocab() == expected.get_vocab()
+    assert len(target) == len(expected)
+
+    # Other attributes
+    assert getattr(target, "padding_side",
+                   None) == getattr(expected, "padding_side", None)
+    assert target.encode("prompt") == expected.encode("prompt")
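The new pickle round-trip in this test guards against a failure mode that is easy to reproduce outside vLLM: pickle locates a class by its module-level name, so an instance whose class was created inside a function cannot be serialized. A minimal standalone sketch of that failure (illustrative only, not vLLM code):

import pickle


def make_instance():
    # A class defined inside a function is not reachable as a module
    # attribute, so pickle cannot serialize instances of it.
    class Local:
        pass

    return Local()


try:
    pickle.dumps(make_instance())
except (pickle.PicklingError, AttributeError) as exc:
    # The exact exception type can vary across Python versions.
    print(f"pickling failed as expected: {exc}")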

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)
 
-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
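One behavioral consequence of the copy.copy() call above: the caller's tokenizer is no longer patched in place; only the shallow copy gets the dynamic subclass, while the original keeps its type and shares the underlying state. A small generic sketch of that pattern (toy classes, not vLLM code):

import copy


class Plain:

    def __init__(self) -> None:
        self.value = 42


def make_proxy(obj: Plain) -> Plain:
    # Shallow-copy the object, then swap only the copy's class;
    # the original object's type is left untouched.
    proxied = copy.copy(obj)

    class Proxy(obj.__class__):

        @property
        def cached_value(self) -> int:
            return 99

    proxied.__class__ = Proxy
    return proxied


plain = Plain()
proxy = make_proxy(plain)
assert type(plain) is Plain       # original is unchanged
assert proxy.cached_value == 99   # the copy gained the proxy behaviour
assert proxy.value == 42          # shallow copy shares the original state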
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids
 
         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
             return tokenizer_all_special_tokens
 
         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
             return tokenizer_all_special_tokens_extended
 
         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id
 
-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
 
-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len
 
+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
 
-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer
 
 
 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
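How the __reduce__ hook makes the dynamic proxy picklable: when an instance is pickled, pickle invokes __reduce__, receives the module-level factory together with the wrapped tokenizer (whose class is an ordinary importable one), and on load simply calls get_cached_tokenizer again to rebuild a fresh proxy. A self-contained analogue of that mechanism, using toy names rather than the real tokenizer types:

import copy
import pickle


class Widget:
    """Stand-in for a real tokenizer class importable from a module."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size


def make_cached(widget: Widget) -> Widget:
    cached = copy.copy(widget)

    class CachedWidget(widget.__class__):

        # Tell pickle to serialize the *underlying* widget and rebuild the
        # dynamic proxy on load by calling make_cached() again, instead of
        # trying (and failing) to pickle this locally defined class.
        def __reduce__(self):
            return make_cached, (widget, )

    cached.__class__ = CachedWidget
    return cached


cached = make_cached(Widget(vocab_size=100))
restored = pickle.loads(pickle.dumps(cached))
assert restored.vocab_size == 100
assert type(restored).__name__ == "CachedWidget"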

View File

@@ -2,7 +2,7 @@
 
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
 
     @property
     @abstractmethod
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         raise NotImplementedError()
 
     @property
@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
     @abstractmethod
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def encode(self,
                text: str,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
+               add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
                skip_special_tokens: bool = True) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         raise NotImplementedError()
 
 
 class TokenizerRegistry:
     # Tokenizer name -> (tokenizer module, tokenizer class)
-    REGISTRY: Dict[str, Tuple[str, str]] = {}
+    REGISTRY: dict[str, tuple[str, str]] = {}
 
     @staticmethod
     def register(name: str, module: str, class_name: str) -> None:
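For reference, a hedged sketch of how a plugin might use the registry shown above. The only call taken from this file is TokenizerRegistry.register(name, module, class_name); the import path vllm.transformers_utils.tokenizer_base is assumed, and the module path and class name below are placeholders.

from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

# Hypothetical registration of a custom TokenizerBase implementation.
# "my_package.tokenization" and "MyCustomTokenizer" are placeholders.
TokenizerRegistry.register(
    name="my_custom_tokenizer",
    module="my_package.tokenization",
    class_name="MyCustomTokenizer",
)

# The registry records the (module, class_name) pair under the name;
# the module is presumably imported lazily via importlib when needed.
print(TokenizerRegistry.REGISTRY["my_custom_tokenizer"])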

View File

@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
     # the following attributes are set to fit vLLM's design and are used
     # by the guided structured output backends.
     @property
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # tekken defines its own extended special tokens list
@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
         ]
 
     @property
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         return self.all_special_tokens_extended
 
     @property
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         return [
             self.all_special_tokens.index(t) for t in self.all_special_tokens
         ]
@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
         input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         # NB: the dictionary form of the vocabulary collapses token ids that map
         # to the same string but have different bytes
         return self._vocab_dict
 
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         # Mistral tokenizers have no added vocabulary
         return {}
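A tiny illustration (not vLLM code) of the caveat in the get_vocab comment above: building a {token_string: token_id} dict keeps only one id per string, so distinct ids whose pieces render to the same string are collapsed.

# Toy "vocabulary": two distinct ids (1 and 2) share the string form "a".
pieces = [(0, "<unk>"), (1, "a"), (2, "a"), (3, "b")]

vocab_dict = {token: token_id for token_id, token in pieces}
print(vocab_dict)                     # {'<unk>': 0, 'a': 2, 'b': 3}
print(len(pieces), len(vocab_dict))   # 4 vs 3 -- id 1 is no longer reachable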