diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index ebf107217c3cb..926ad2503398c 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -334,20 +334,20 @@ class TestMistralTokenizer: def test_encode(self, mistral_tokenizer: MistralTokenizer): token_ids = ( - [1, 22177, 4304, 2662, 2] + [1, 22177, 4304, 2662] if mistral_tokenizer.is_tekken - else [1, 23325, 2294, 1686, 2] + else [1, 23325, 2294, 1686] ) - assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1] - assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2] + assert mistral_tokenizer.encode("Hello world !") == token_ids + assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-1] assert ( mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3) - == token_ids[:-2] + == token_ids[:-1] ) assert ( mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3) - == token_ids[:-1] + == token_ids ) assert ( @@ -358,7 +358,7 @@ class TestMistralTokenizer: mistral_tokenizer.encode( "Hello world !", add_special_tokens=True, max_length=3 ) - == token_ids[:-2] + == token_ids[:-1] ) assert ( mistral_tokenizer.encode( @@ -368,7 +368,7 @@ class TestMistralTokenizer: ) assert ( mistral_tokenizer.encode("Hello world !", add_special_tokens=False) - == token_ids[1:-1] + == token_ids[1:] ) @pytest.mark.parametrize( @@ -1088,6 +1088,19 @@ class TestMistralTokenizer: == expected_tokens[mistral_tokenizer.is_tekken] ) + def test_decode_int( + self, + mistral_tokenizer: MistralTokenizer, + ): + ids = 1 + assert ( + mistral_tokenizer.decode( + ids, + skip_special_tokens=False, + ) + == "" + ) + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): tokens = ( [ diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 7033523224c51..34433484fc14e 100644 --- 
a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -165,6 +165,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: class MistralTokenizer(TokenizerBase): def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: + from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -175,6 +176,14 @@ class MistralTokenizer(TokenizerBase): self.instruct = self.mistral.instruct_tokenizer self.tokenizer = self.instruct.tokenizer + mode = self.mistral._chat_completion_request_validator._mode + if mode != ValidationMode.test: + raise ValueError( + "Mistral tokenizer must be in test mode. Make sure to " + "set `mode=ValidationMode.test` when creating the " + "Mistral tokenizer." + ) + _mistral_version_str = str(self.tokenizer.version.value) self.version: int = int(_mistral_version_str.split("v")[-1]) @@ -205,6 +214,7 @@ class MistralTokenizer(TokenizerBase): def from_pretrained( cls, path_or_repo_id: str, *, revision: str | None = None ) -> "MistralTokenizer": + from mistral_common.protocol.instruct.validator import ValidationMode from transformers.tokenization_mistral_common import ( MistralCommonTokenizer as TransformersMistralTokenizer, ) @@ -212,7 +222,7 @@ class MistralTokenizer(TokenizerBase): str_revision = "main" if revision is None else revision return cls( TransformersMistralTokenizer.from_pretrained( - path_or_repo_id, revision=str_revision + path_or_repo_id, revision=str_revision, mode=ValidationMode.test ) ) @@ -339,20 +349,14 @@ class MistralTokenizer(TokenizerBase): max_length: int | None = None, add_special_tokens: bool | None = None, ) -> list[int]: - if add_special_tokens is not None: - return self.transformers_tokenizer.encode( - text, - truncation=truncation, - max_length=max_length, - add_special_tokens=add_special_tokens, - ) - else: - encoded = 
self.tokenizer.encode(text, bos=True, eos=False) + encoded = self.tokenizer.encode( + text, bos=add_special_tokens is not False, eos=False + ) - if truncation is not False and max_length is not None: - return encoded[:max_length] - else: - return encoded + if truncation is not False and max_length is not None: + return encoded[:max_length] + else: + return encoded def apply_chat_template( self, @@ -383,6 +387,9 @@ class MistralTokenizer(TokenizerBase): ) def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: + if isinstance(ids, int): + ids = [ids] + return self.transformers_tokenizer.decode( ids, skip_special_tokens=skip_special_tokens )