mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:06:10 +08:00
[Misc] Make cached tokenizer pickle-compatible (#17048)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8e4b351a0c
commit
93a126fbc7
@ -63,14 +63,16 @@ class Request:
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||
length: int) -> list[int]:
|
||||
vocab = tokenizer.get_vocab()
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
|
||||
# Remove the special tokens.
|
||||
vocab = {
|
||||
k: v
|
||||
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||
}
|
||||
return random.choices(list(vocab.values()), k=length)
|
||||
return random.choices(
|
||||
[v for k, v in vocab.items() if k not in all_special_ids],
|
||||
k=length,
|
||||
)
|
||||
|
||||
|
||||
def sample_requests_from_dataset(
|
||||
|
||||
@ -1,24 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
get_cached_tokenizer)
|
||||
|
||||
|
||||
def test_cached_tokenizer():
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
|
||||
def test_cached_tokenizer(model_id: str):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
|
||||
trust_remote_code=True)
|
||||
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
|
||||
reference_tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": ["<SEP>"]})
|
||||
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
|
||||
|
||||
assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
|
||||
"prompt")
|
||||
assert set(reference_tokenizer.all_special_ids) == set(
|
||||
cached_tokenizer.all_special_ids)
|
||||
assert set(reference_tokenizer.all_special_tokens) == set(
|
||||
cached_tokenizer.all_special_tokens)
|
||||
assert set(reference_tokenizer.all_special_tokens_extended) == set(
|
||||
cached_tokenizer.all_special_tokens_extended)
|
||||
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
|
||||
_check_consistency(cached_tokenizer, reference_tokenizer)
|
||||
|
||||
pickled_tokenizer = pickle.dumps(cached_tokenizer)
|
||||
unpickled_tokenizer = pickle.loads(pickled_tokenizer)
|
||||
_check_consistency(unpickled_tokenizer, reference_tokenizer)
|
||||
|
||||
|
||||
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
|
||||
assert isinstance(target, type(expected))
|
||||
|
||||
# Cached attributes
|
||||
assert target.all_special_ids == expected.all_special_ids
|
||||
assert target.all_special_tokens == expected.all_special_tokens
|
||||
assert (target.all_special_tokens_extended ==
|
||||
expected.all_special_tokens_extended)
|
||||
assert target.get_vocab() == expected.get_vocab()
|
||||
assert len(target) == len(expected)
|
||||
|
||||
# Other attributes
|
||||
assert getattr(target, "padding_side",
|
||||
None) == getattr(expected, "padding_side", None)
|
||||
|
||||
assert target.encode("prompt") == expected.encode("prompt")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
@ -70,18 +71,17 @@ def encode_tokens(
|
||||
|
||||
|
||||
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
"""Get tokenizer with cached properties.
|
||||
|
||||
This will patch the tokenizer object in place.
|
||||
|
||||
"""
|
||||
By default, transformers will recompute multiple tokenizer properties
|
||||
each time they are called, leading to a significant slowdown. This
|
||||
function caches these properties for faster access."""
|
||||
each time they are called, leading to a significant slowdown.
|
||||
This proxy caches these properties for faster access.
|
||||
"""
|
||||
cached_tokenizer = copy.copy(tokenizer)
|
||||
|
||||
tokenizer_all_special_ids = set(tokenizer.all_special_ids)
|
||||
tokenizer_all_special_ids = tokenizer.all_special_ids
|
||||
tokenizer_all_special_tokens = tokenizer.all_special_tokens
|
||||
tokenizer_all_special_tokens_extended = (
|
||||
tokenizer.all_special_tokens_extended)
|
||||
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
|
||||
tokenizer_vocab = tokenizer.get_vocab()
|
||||
tokenizer_len = len(tokenizer)
|
||||
|
||||
@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
class CachedTokenizer(tokenizer.__class__): # type: ignore
|
||||
|
||||
@property
|
||||
def all_special_ids(self):
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return tokenizer_all_special_ids
|
||||
|
||||
@property
|
||||
def all_special_tokens(self):
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return tokenizer_all_special_tokens
|
||||
|
||||
@property
|
||||
def all_special_tokens_extended(self):
|
||||
def all_special_tokens_extended(self) -> list[str]:
|
||||
return tokenizer_all_special_tokens_extended
|
||||
|
||||
@property
|
||||
def max_token_id(self):
|
||||
def max_token_id(self) -> int:
|
||||
return max_token_id
|
||||
|
||||
def get_vocab(self):
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
return tokenizer_vocab
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return tokenizer_len
|
||||
|
||||
def __reduce__(self):
|
||||
return get_cached_tokenizer, (tokenizer, )
|
||||
|
||||
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
|
||||
|
||||
tokenizer.__class__ = CachedTokenizer
|
||||
return tokenizer
|
||||
cached_tokenizer.__class__ = CachedTokenizer
|
||||
return cached_tokenizer
|
||||
|
||||
|
||||
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
@ -12,17 +12,17 @@ class TokenizerBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens_extended(self) -> List[str]:
|
||||
def all_special_tokens_extended(self) -> list[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_special_ids(self) -> List[int]:
|
||||
def all_special_ids(self) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@ -66,7 +66,7 @@ class TokenizerBase(ABC):
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str], List[int]],
|
||||
text: Union[str, list[str], list[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
@ -75,11 +75,11 @@ class TokenizerBase(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
@ -88,44 +88,44 @@ class TokenizerBase(ABC):
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def encode(self,
|
||||
text: str,
|
||||
add_special_tokens: Optional[bool] = None) -> List[int]:
|
||||
add_special_tokens: Optional[bool] = None) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
**kwargs) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def decode(self,
|
||||
ids: Union[List[int], int],
|
||||
ids: Union[list[int], int],
|
||||
skip_special_tokens: bool = True) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: List[int],
|
||||
ids: list[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenizerRegistry:
|
||||
# Tokenizer name -> (tokenizer module, tokenizer class)
|
||||
REGISTRY: Dict[str, Tuple[str, str]] = {}
|
||||
REGISTRY: dict[str, tuple[str, str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def register(name: str, module: str, class_name: str) -> None:
|
||||
|
||||
@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
# the following attributes are set to fit vLLM's design and are used
|
||||
# by the guided structured output backends.
|
||||
@property
|
||||
def all_special_tokens_extended(self) -> List[str]:
|
||||
def all_special_tokens_extended(self) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
|
||||
# tekken defines its own extended special tokens list
|
||||
@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
|
||||
]
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> List[str]:
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return self.all_special_tokens_extended
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> List[int]:
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return [
|
||||
self.all_special_tokens.index(t) for t in self.all_special_tokens
|
||||
]
|
||||
@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
|
||||
input_ids = self.encode_one(text, truncation, max_length)
|
||||
return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
# NB: the dictionary form of the vocabulary collapses token ids that map
|
||||
# to the same string but have different bytes
|
||||
return self._vocab_dict
|
||||
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
# Mistral tokenizers have no added vocabulary
|
||||
return {}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user