From b2c1d294faca96643dbc2413d604ca160f458f0d Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 28 Nov 2025 09:44:47 +0100
Subject: [PATCH] [BUGFIX] MistralTokenizer.__call__ adds an invalid EOS token
 (#29607)

Signed-off-by: Julien Denize
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung
---
 tests/tokenization/test_mistral_tokenizer.py  | 68 +++++++++++++++++++
 vllm/transformers_utils/tokenizers/mistral.py | 20 +++++-
 2 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py
index 1ada8ee187c38..c80b698ba3848 100644
--- a/tests/tokenization/test_mistral_tokenizer.py
+++ b/tests/tokenization/test_mistral_tokenizer.py
@@ -331,6 +331,7 @@ class TestMistralTokenizer:
             )
             == token_ids
         )
+        assert mistral_tokenizer.encode_one("") == []
 
     def test_encode(self, mistral_tokenizer: MistralTokenizer):
         token_ids = (
@@ -370,6 +371,51 @@
             mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
             == token_ids[1:]
         )
+        assert mistral_tokenizer.encode("", add_special_tokens=False) == []
+
+    def test_call(self, mistral_tokenizer: MistralTokenizer):
+        token_ids = (
+            [1, 22177, 4304, 2662]
+            if mistral_tokenizer.is_tekken
+            else [1, 23325, 2294, 1686]
+        )
+        attn_mask = [1 for _ in range(len(token_ids))]
+
+        # Test 1: default
+        assert mistral_tokenizer("Hello world !") == {
+            "attention_mask": attn_mask[1:],
+            "input_ids": token_ids[1:],
+        }
+        # Test 2: special tokens
+        assert mistral_tokenizer("Hello world !", add_special_tokens=True) == {
+            "attention_mask": attn_mask,
+            "input_ids": token_ids,
+        }
+        # Test 3: special tokens + truncation
+        assert mistral_tokenizer(
+            "Hello world !", add_special_tokens=True, truncation=True, max_length=3
+        ) == {
+            "attention_mask": attn_mask[:-1],
+            "input_ids": token_ids[:-1],
+        }
+        # Test 4: special tokens + no truncation + max length
+        assert mistral_tokenizer(
+            "Hello world !", add_special_tokens=True, max_length=3
+        ) == {
+            "attention_mask": attn_mask,
+            "input_ids": token_ids,
+        }
+        # Test 5: empty string
+        assert mistral_tokenizer("") == {
+            "attention_mask": [],
+            "input_ids": [],
+        }
+
+        with pytest.raises(
+            ValueError,
+            match=(r"`text_pair` is not supported by `MistralTokenizer.__call__`."),
+        ):
+            mistral_tokenizer("Hello world !", "invalid pair")
 
     @pytest.mark.parametrize(
         "openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
@@ -1087,6 +1133,24 @@
             )
             == expected_tokens[mistral_tokenizer.is_tekken]
         )
+        assert (
+            mistral_tokenizer.decode(
+                ids[mistral_tokenizer.is_tekken],
+                skip_special_tokens=skip_special_tokens,
+            )
+            == expected_tokens[mistral_tokenizer.is_tekken]
+        )
+
+    def test_decode_empty(
+        self,
+        mistral_tokenizer: MistralTokenizer,
+    ):
+        assert (
+            mistral_tokenizer.decode(
+                [],
+            )
+            == ""
+        )
 
     def test_decode_int(
         self,
@@ -1390,6 +1454,8 @@
             == expected_strings[mistral_tokenizer.is_tekken]
         )
 
+        assert mistral_tokenizer.convert_tokens_to_string([]) == ""
+
     @pytest.mark.parametrize(
         "skip_special_tokens,tuple_expected_tokens",
         (
@@ -2220,3 +2286,5 @@
             ids, skip_special_tokens=skip_special_tokens
         )
         assert actual_tokens == expected_tokens
+
+        assert mistral_tokenizer.convert_ids_to_tokens([]) == []
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 39198a1f3d815..caff43c55ce85 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -312,13 +312,27 @@ class MistralTokenizer(TokenizerBase):
         truncation: bool = False,
         max_length: int | None = None,
     ):
-        return self.transformers_tokenizer(
+        if text_pair is not None:
+            raise ValueError(
+                "`text_pair` is not supported by `MistralTokenizer.__call__`."
+            )
+
+        encoded = self.transformers_tokenizer(
             text=text,
             text_pair=text_pair,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
             max_length=max_length,
         )
+        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
+        # is in, revert to only call self.transformers_tokenizer(...).
+        # Hack to fix the wrongly added EOS token; once the upstream fix is available,
+        # the condition below will be False even before the revert is done.
+        if encoded["input_ids"] and encoded["input_ids"][-1] == self.eos_token_id:
+            encoded["input_ids"].pop(-1)
+            if attention_mask := encoded.get("attention_mask"):
+                attention_mask.pop(-1)
+        return encoded
 
     @property
     def vocab(self) -> list[str]:
@@ -349,6 +363,8 @@
         max_length: int | None = None,
         add_special_tokens: bool | None = None,
     ) -> list[int]:
+        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
+        # is in, directly call self.transformers_tokenizer.encode(...).
         encoded = self.tokenizer.encode(
             text, bos=add_special_tokens is not False, eos=False
         )
@@ -387,6 +403,8 @@
         )
 
     def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
+        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
+        # is in, directly call self.transformers_tokenizer.decode(...).
         if isinstance(ids, int):
             ids = [ids]
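
For readers skimming the diff, the snippet below is a minimal, standalone Python sketch of the EOS-stripping workaround that the mistral.py hunk adds to MistralTokenizer.__call__. The helper name strip_trailing_eos, the EOS_TOKEN_ID value, and the sample token ids are illustrative assumptions and not part of the patch; the real code operates on the encoding returned by the wrapped transformers tokenizer and uses self.eos_token_id.

# Standalone sketch of the EOS-stripping workaround (illustrative names/values).
EOS_TOKEN_ID = 2  # hypothetical EOS id; the real value comes from the tokenizer


def strip_trailing_eos(encoded: dict, eos_token_id: int) -> dict:
    """Drop a trailing EOS token (and its attention-mask entry) if present."""
    input_ids = encoded["input_ids"]
    if input_ids and input_ids[-1] == eos_token_id:
        input_ids.pop(-1)
        # Keep the attention mask aligned with input_ids.
        if attention_mask := encoded.get("attention_mask"):
            attention_mask.pop(-1)
    return encoded


# A prompt that the wrapped tokenizer wrongly terminated with EOS (id 2).
print(strip_trailing_eos(
    {"input_ids": [1, 22177, 4304, 2662, 2], "attention_mask": [1, 1, 1, 1, 1]},
    EOS_TOKEN_ID,
))
# -> {'input_ids': [1, 22177, 4304, 2662], 'attention_mask': [1, 1, 1, 1]}

# Empty input passes through unchanged, matching the new empty-string tests.
print(strip_trailing_eos({"input_ids": [], "attention_mask": []}, EOS_TOKEN_ID))
# -> {'input_ids': [], 'attention_mask': []}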