From c6187f55f7c4844ed9ff5630d41114cbe6fccb6b Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 10 Oct 2025 00:48:58 +0200 Subject: [PATCH] Refactor MistralTokenizer (#26358) Signed-off-by: Julien Denize --- docs/features/tool_calling.md | 11 +- examples/offline_inference/audio_language.py | 4 +- requirements/common.txt | 2 +- requirements/nightly_torch_test.txt | 2 +- requirements/test.in | 2 +- requirements/test.txt | 4 +- tests/entrypoints/test_chat_utils.py | 30 +- .../multimodal/generation/test_pixtral.py | 2 +- .../multimodal/generation/test_voxtral.py | 8 +- .../multimodal/processing/test_common.py | 3 +- .../processing/test_tensor_schema.py | 3 +- .../test_mistral_reasoning_parser.py | 28 +- tests/tokenization/test_mistral_tokenizer.py | 2152 ++++++++++++++++- vllm/entrypoints/chat_utils.py | 18 +- vllm/model_executor/models/pixtral.py | 3 +- vllm/model_executor/models/voxtral.py | 8 +- vllm/transformers_utils/tokenizers/mistral.py | 501 ++-- vllm/v1/structured_output/backend_xgrammar.py | 33 +- 18 files changed, 2351 insertions(+), 463 deletions(-) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 6a0bcfac66d0a..e57a8945971f5 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -145,7 +145,7 @@ Supported models: Known issues: 1. Mistral 7B struggles to generate parallel tool calls correctly. -2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is much shorter than what vLLM generates. Since an exception is thrown when this condition is not met, the following additional chat templates are provided: @@ -154,7 +154,14 @@ Known issues: * - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. -Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` +Recommended flags: + +1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend: + + `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral` + +2. 
To use the default Transformers tokenization backend: + `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` ### Llama Models (`llama3_json`) diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 65a87d2dd9e8e..a36664e470450 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): # Voxtral def run_voxtral(question: str, audio_count: int) -> ModelRequestData: from mistral_common.audio import Audio - from mistral_common.protocol.instruct.messages import ( + from mistral_common.protocol.instruct.chunk import ( AudioChunk, RawAudio, TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( UserMessage, ) from mistral_common.protocol.instruct.request import ChatCompletionRequest diff --git a/requirements/common.txt b/requirements/common.txt index a87e77dc9901d..d5fa1e92bd7eb 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -32,7 +32,7 @@ pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata; python_version < '3.10' -mistral_common[image,audio] >= 1.8.2 +mistral_common[image,audio] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 52b5d269db30c..dea1926bbd695 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test diff --git a/requirements/test.in b/requirements/test.in index bf69628e67b2d..f0941d3c59183 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -29,7 +29,7 @@ torchaudio==2.8.0 torchvision==0.23.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test diff --git a/requirements/test.txt b/requirements/test.txt index 75d3d40f61346..03fbdcc8d453b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -474,7 +474,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.8.2 +mistral-common==1.8.5 # via -r requirements/test.in mlflow==2.22.0 # via terratorch @@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1 # via # -r requirements/test.in # mteb -sentencepiece==0.2.0 - # via mistral-common setuptools==77.0.3 # via # lightning-utilities diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6e92419c4f67d..dcd196ebdd772 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,8 +6,7 @@ from collections.abc import Mapping from typing import Literal, Optional import pytest 
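As a quick illustration of the first set of recommended flags above (a sketch only, not part of this patch): the `get_weather` tool schema is borrowed from the tests added later in this diff, while the server address, API key, and the `--enable-auto-tool-choice` flag are assumptions.

```python
# Sketch only: assumes a vLLM server started with the mistral-common backend flags
# recommended above, e.g.
#   vllm serve mistralai/Mistral-7B-Instruct-v0.3 \
#     --tokenizer_mode mistral --config_format mistral --load_format mistral \
#     --tool-call-parser mistral --enable-auto-tool-choice
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Gets the current weather in a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "The city name"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```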
-from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens -from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk(): }, {"role": "user", "content": "Thanks, what is 3+3?"}, ] - - # TODO(Julien): upon model release change to a tokenizer already configured. - # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507" + "mistralai/Magistral-Small-2509" ) - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value - ) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value - ) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value - ] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value - ] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= tokens_ids = apply_mistral_chat_template( mistral_tokenizer, messages, chat_template=None, tools=None diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index db0effdaf6664..bde07da9101ac 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional import pytest from mistral_common.multimodal import download_image -from mistral_common.protocol.instruct.messages import ImageURLChunk +from mistral_common.protocol.instruct.chunk import ImageURLChunk from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index d27b3ab5ff475..18a50c3a555da 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -6,12 +6,8 @@ import json import pytest import pytest_asyncio from mistral_common.audio import Audio -from mistral_common.protocol.instruct.messages import ( - AudioChunk, - RawAudio, - TextChunk, - UserMessage, -) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from vllm.transformers_utils.tokenizer import MistralTokenizer diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index d9d85f7e0c007..5c872143a07e6 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -6,7 +6,8 @@ from typing import Optional, 
Union import numpy as np import pytest -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 2c4d109c36875..6b6c53a50397b 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -9,7 +9,8 @@ from typing import Any, Union import numpy as np import pytest import torch.nn as nn -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py index 96107c0c1193b..ff7f94b40ee11 100644 --- a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from mistral_common.tokens.tokenizers.base import SpecialTokens -from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -14,33 +12,9 @@ parser_name = "mistral" @pytest.fixture(scope="module") def mistral_tokenizer(): - # TODO(Julien): upon model release change to a tokenizer already configured. 
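The import moves above (chunk types now coming from `mistral_common.protocol.instruct.chunk` rather than `mistral_common.protocol.instruct.messages`) follow the module layout of mistral_common >= 1.8.5, which the requirements bump earlier in this patch pins. A minimal sketch of the new convention (illustrative only, not part of the patch):

```python
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

# Chunk types (Text/Image/Audio...) now live in the `chunk` module; message and
# request types stay in `messages` / `request`.
request = ChatCompletionRequest(
    messages=[
        UserMessage(
            content=[
                TextChunk(text="Describe this image."),
                ImageChunk(image=Image.new("RGB", (64, 64))),
            ]
        )
    ]
)
```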
- # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507" + "mistralai/Magistral-Small-2509" ) - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value - ) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value - ) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value - ] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value - ] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= return mistral_tokenizer diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index a034188387d01..ebf107217c3cb 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -1,27 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import pytest -from mistral_common.protocol.instruct.messages import ( - AssistantMessage, - ToolMessage, - UserMessage, -) -from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.protocol.instruct.tool_calls import ( - Function, - FunctionCall, - Tool, - ToolCall, -) +from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.transformers_utils.tokenizers.mistral import ( - make_mistral_chat_completion_request, + MistralTokenizer, + _prepare_apply_chat_template_tools_and_messages, ) @pytest.mark.parametrize( - "openai_request,expected_mistral_request", + "openai_request,expected_mistral_output", [ ( { @@ -41,19 +34,22 @@ from vllm.transformers_utils.tokenizers.mistral import ( } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } ], ), ), @@ -71,39 +67,44 @@ from vllm.transformers_utils.tokenizers.mistral import ( "function": { "description": "Fetch the current local date and time.", "name": "get_current_time", - "parameters": None, + "parameters": {}, }, } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) + [ + { + "type": "function", 
+ "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } ], ), ), ], ) -def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): - actual_request = make_mistral_chat_completion_request( +def test_prepare_apply_chat_template_tools_and_messages( + openai_request, expected_mistral_output +): + actual_request = _prepare_apply_chat_template_tools_and_messages( openai_request["messages"], openai_request["tools"] ) - assert actual_request == expected_mistral_request + assert actual_request == expected_mistral_output # Tool use with list content and reasoning_content @pytest.mark.parametrize( - "openai_request,expected_mistral_request", + "openai_request,expected_mistral_output", [ ( { @@ -154,34 +155,40 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What's the weather in Paris?"), - AssistantMessage( - content=None, - tool_calls=[ - ToolCall( - id="call123", - function=FunctionCall( - name="get_weather", - arguments='{"city": "Paris"}', - ), - ) + ( + [ + { + "role": "user", + "content": "What's the weather in Paris?", + }, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } ], - ), - ToolMessage( - content="Rainy", - tool_call_id="call123", - name="get_weather", - ), + }, + { + "role": "tool", + "content": [{"type": "text", "text": "Rainy"}], + "name": "get_weather", + "tool_call_id": "call123", + }, ], - tools=[ - Tool( - type="function", - function=Function( - name="get_weather", - description="Gets the current weather in a city.", - parameters={ + [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { "type": "object", "properties": { "city": { @@ -191,17 +198,2012 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r }, "required": ["city"], }, - ), - ) + }, + } ], ), ) ], ) -def test_make_mistral_chat_completion_request_list_content( - openai_request, expected_mistral_request +def test_prepare_apply_chat_template_tools_and_messages_list_content( + openai_request, expected_mistral_output ): - actual_request = make_mistral_chat_completion_request( + actual_request = _prepare_apply_chat_template_tools_and_messages( openai_request["messages"], openai_request["tools"] ) - assert actual_request == expected_mistral_request + assert actual_request == expected_mistral_output + + +def test_prepare_apply_chat_template_generation_prompt_and_continue(): + messages = [{"role": "assistant", "content": "Hello"}] + tools: list[dict[str, Any]] = [] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + + messages = [{"role": "user", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + assert out_messages == [{"role": "user", "content": "Hello"}] + + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True, continue_final_message=True + ) + + messages = [{"role": "assistant", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, 
add_generation_prompt=False, continue_final_message=True + ) + assert out_messages == [{"role": "assistant", "content": "Hello"}] + + messages = [{"role": "user", "content": "Hello"}] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=False, continue_final_message=True + ) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(request) -> MistralTokenizer: + return MistralTokenizer.from_pretrained(request.param) + + +@pytest.mark.parametrize( + "mistral_tokenizer", + ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"], + indirect=True, +) +class TestMistralTokenizer: + def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer): + attributes = [ + mistral_tokenizer.all_special_tokens, + mistral_tokenizer.all_special_tokens_extended, + ] + + for attribute in attributes: + if mistral_tokenizer.is_tekken: + assert attribute == [ + "", + "", + "", + "[INST]", + "[/INST]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + "[TOOL_CALLS]", + "[IMG]", + "", + "[IMG_BREAK]", + "[IMG_END]", + "[PREFIX]", + "[MIDDLE]", + "[SUFFIX]", + "[SYSTEM_PROMPT]", + "[/SYSTEM_PROMPT]", + "[TOOL_CONTENT]", + ] + [f"" for i in range(20, 32)] + [ + "[ARGS]", + "[CALL_ID]", + "[THINK]", + "[/THINK]", + ] + [f"" for i in range(36, 1000)] + else: + assert attribute == [ + "", + "", + "[INST]", + "[/INST]", + "[TOOL_CALLS]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + ] + [f"[control_{i}]" for i in range(8, 769)] + + def get_vocab(self, mistral_tokenizer: MistralTokenizer): + assert ( + mistral_tokenizer.get_vocab() + == mistral_tokenizer.transformers_tokenizer.get_vocab() + ) + + def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer): + assert mistral_tokenizer.get_added_vocab() == {} + + def test_encode_one(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686] + ) + + assert mistral_tokenizer.encode_one("Hello world !") == token_ids + assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids + assert ( + mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode_one( + "Hello world !", truncation=False, max_length=1 + ) + == token_ids + ) + + def test_encode(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [1, 22177, 4304, 2662, 2] + if mistral_tokenizer.is_tekken + else [1, 23325, 2294, 1686, 2] + ) + + assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1] + assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2] + assert ( + mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3) + == token_ids[:-1] + ) + + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=True) + == token_ids + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, max_length=3 + ) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, truncation=False, max_length=3 + ) + == token_ids + ) + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=False) + == token_ids[1:-1] + ) + + @pytest.mark.parametrize( + 
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output", + [ + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + } + ], + }, + True, + False, + ([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]), + ("[INST]▁Hello▁world▁![/INST]", ("[INST]Hello world ![/INST]")), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + }, + True, + False, + ( + [1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4], + [1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4], + ), + ( + "[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]", + ( + "[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]" # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + ], + ), + ( + '[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]', + ( + '[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + { + "role": "assistant", + 
"content": "", + "tool_calls": [ + { + "id": "123456789", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "123456789", + "content": '{"temperature": 20, "unit": "celsius"}', + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ), + ( + '[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}][TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]', + ( + '[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": 
{"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + }, + { + "role": "assistant", + "content": "Hello ", + }, + ], + }, + False, + True, + ( + [1, 3, 23325, 2294, 1686, 4, 23325], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ), + ( + "[INST]▁Hello▁world▁![/INST]▁Hello", + ("[INST]Hello world ![/INST]Hello"), + ), + ), + ], + ) + def test_apply_chat_template( + self, + mistral_tokenizer: MistralTokenizer, + openai_request: dict[str, Any], + add_generation_prompt: bool, + continue_final_message: bool, + expected_output: tuple[list[int], list[int]], + decoded_expected_output: tuple[str, str], + ): + actual_output = mistral_tokenizer.apply_chat_template( + openai_request["messages"], + tools=openai_request.get("tools", []), + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + decoded_actual_output = mistral_tokenizer.tokenizer.decode( + actual_output, SpecialTokenPolicy.KEEP + ) + + assert actual_output == expected_output[mistral_tokenizer.is_tekken] + assert ( + decoded_actual_output + == decoded_expected_output[mistral_tokenizer.is_tekken] + ) + + def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer): + messages = [{"role": "user", "content": "Hello world !"}] + + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=True, + ) + + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=True, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=False, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(InvalidMessageStructureException): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=False, + ) + + @pytest.mark.parametrize( + "skip_special_tokens,expected_tokens", + ( + ( + False, + ( + "[INST]▁Hello▁world▁![/INST]▁Hello", + "[INST]Hello world ![/INST]Hello", + ), + ), + (True, ("Hello world ! 
Hello", "Hello world !Hello")), + ), + ) + def test_decode( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + expected_tokens: tuple[str, str], + ): + ids = ( + [1, 3, 23325, 2294, 1686, 4, 23325, 2], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ) + assert ( + mistral_tokenizer.decode( + ids[mistral_tokenizer.is_tekken], + skip_special_tokens=skip_special_tokens, + ) + == expected_tokens[mistral_tokenizer.is_tekken] + ) + + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): + tokens = ( + [ + "", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ) + + expected_strings = ( + '[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": 
"123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}', # noqa: E501 + 'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}', # noqa: E501 + ) + + assert ( + mistral_tokenizer.convert_tokens_to_string( + tokens[mistral_tokenizer.is_tekken] + ) + == expected_strings[mistral_tokenizer.is_tekken] + ) + + @pytest.mark.parametrize( + "skip_special_tokens,tuple_expected_tokens", + ( + ( + True, + ( + [ + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + ], + [ + "I", + " am", + " an", + " AI", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "Hello", + " world", + " !", + "[TOOL_CALLS]", + "get", + "_", + "weather", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + ], + ), + ), + ( + False, + ( + [ + "", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + 
'▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ), + ), + ), + ) + def test_convert_ids_to_tokens( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + tuple_expected_tokens: tuple[list[str], list[str]], + ): + tuple_ids = ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 1113, + 29508, + 29518, + 
29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ) + + ids = tuple_ids[mistral_tokenizer.is_tekken] + expected_tokens = tuple_expected_tokens[mistral_tokenizer.is_tekken] + actual_tokens = mistral_tokenizer.convert_ids_to_tokens( + ids, skip_special_tokens=skip_special_tokens + ) + assert actual_tokens == expected_tokens diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 930b3bc69c3db..e548554dca734 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -403,20 +403,12 @@ def resolve_mistral_chat_template( chat_template: Optional[str], **kwargs: Any, ) -> Optional[str]: - if chat_template is not None: - logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer." - ) - if "add_generation_prompt" in kwargs: - logger.warning_once( - "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored." - ) - if "continue_final_message" in kwargs: - logger.warning_once( - "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored." + if chat_template is not None or kwargs.get("chat_template_kwargs") is not None: + raise ValueError( + "'chat_template' or 'chat_template_kwargs' cannot be overridden " + "for mistral tokenizer." 
) + return None diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 65abebcf37de9..62f642eae4b52 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index f929ba9913ecf..f4bfbd26756e1 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -12,12 +12,8 @@ import regex as re import torch import torch.nn as nn from mistral_common.audio import mel_filter_bank -from mistral_common.protocol.instruct.messages import ( - AudioChunk, - RawAudio, - TextChunk, - UserMessage, -) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 5633a31455e9b..eae067fcfa344 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -1,34 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union, cast -import huggingface_hub -import regex as re -from huggingface_hub import HfApi, hf_hub_download -from transformers.tokenization_utils_base import BatchEncoding - from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_base import TokenizerBase -from vllm.utils import is_list_of if TYPE_CHECKING: - # make sure `mistral_common` is lazy imported, - # so that users who only use non-mistral models - # will not be bothered by the dependency. 
- from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer, + from mistral_common.protocol.instruct.request import ( + ChatCompletionRequest as MistralChatCompletionRequest, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, ) from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from vllm.entrypoints.openai.protocol import ChatCompletionRequest logger = init_logger(__name__) -def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): +def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"): # SEE: https://github.com/vllm-project/vllm/pull/9951 # Credits go to: @gcalmettes # NOTE: There is currently a bug in pydantic where attributes @@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): request.messages[i]["tool_calls"] = validated_tool_calls -def truncate_tool_call_ids(request: "ChatCompletionRequest"): +def truncate_tool_call_ids(request: "MistralChatCompletionRequest"): """Truncates tool call IDs for Mistral's ID requirements.""" for i, message in enumerate(request.messages): if message.get("role") == "assistant": @@ -95,84 +88,34 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): request.messages[i]["tool_call_id"] = tool_call_id -def validate_request_params(request: "ChatCompletionRequest"): - if request.skip_special_tokens is not None and not request.skip_special_tokens: - raise ValueError( - "skip_special_tokens=False is not supported for Mistral tokenizers." - ) - - -def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]: - repo_cache = os.path.join( - huggingface_hub.constants.HF_HUB_CACHE, - huggingface_hub.constants.REPO_ID_SEPARATOR.join( - ["models", *repo_id.split("/")] - ), - ) - - if revision is None: - revision_file = os.path.join(repo_cache, "refs", "main") - if os.path.isfile(revision_file): - with open(revision_file) as file: - revision = file.read() - - if revision: - revision_dir = os.path.join(repo_cache, "snapshots", revision) - if os.path.isdir(revision_dir): - return os.listdir(revision_dir) - - return [] - - -def find_tokenizer_file(files: list[str]): - # Accept both versioned (tokenizer.model.v3) and unversioned - # (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model - # variants. Previous pattern only matched the versioned variants. - file_pattern = re.compile( - r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$" - ) - - matched_files = [file for file in files if file_pattern.match(file)] - if len(matched_files) > 1: - logger.warning( - "Multiple files matched pattern `%s`: %s. Using %s.", - file_pattern.pattern, - matched_files, - matched_files[0], - ) - elif len(matched_files) == 0: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}." 
- ) - - return matched_files[0] - - -def _aggregate_content(content: list) -> list[dict[str, Any]]: - aggregated_content: list[dict[str, Any]] = [] - for chunk in content: - if ( - chunk.get("type") == "text" - and aggregated_content - and aggregated_content[-1].get("type") == "text" - ): - aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") - else: - aggregated_content.append(chunk) - if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text": - content = aggregated_content[0]["text"] - return content - - -def make_mistral_chat_completion_request( +def _prepare_apply_chat_template_tools_and_messages( messages: list["ChatCompletionMessageParam"], tools: Optional[list[dict[str, Any]]] = None, -) -> "ChatCompletionRequest": + continue_final_message: bool = False, + add_generation_prompt: bool = False, +) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]: + if add_generation_prompt and continue_final_message: + raise ValueError( + "Cannot set both `add_generation_prompt` and " + "`continue_final_message` to True." + ) + last_message = cast(dict[str, Any], messages[-1]) - if last_message["role"] == "assistant": - last_message["prefix"] = True + # add_generation_prompt is directly handled by the tokenizer but we + # check if the user is trying to use it with a final assistant message + # which is probably not what they want. + # If add_generation_prompt is False, we don't need to check anything. + if add_generation_prompt and last_message["role"] == "assistant": + raise ValueError( + "Cannot set `add_generation_prompt` to True when " + "the last message is from the assistant. Consider " + "using `continue_final_message` instead." + ) + if continue_final_message and last_message["role"] != "assistant": + raise ValueError( + "Cannot set `continue_final_message` to True when " + "the last message is not from the assistant." + ) # mistral-common requires AssistantMessage content to be string [1]. # @@ -181,13 +124,6 @@ def make_mistral_chat_completion_request( # Remove reasoning_content as unsupported by Mistral _ = message.pop("reasoning_content", None) # type: ignore - # Convert list text content to string - if message.get("role") in ("assistant", "tool"): - content: Any = message.get("content") - if isinstance(content, list): - content = _aggregate_content(content) - message["content"] = content - # The Mistral client, in comparison to the OpenAI client, requires the # "parameters" dict and the "description" string to be present # even if they are empty. 
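For reference, a small sketch (not part of the patch) of how the helper above behaves, mirroring the unit tests added earlier in this diff:

```python
from vllm.transformers_utils.tokenizers.mistral import (
    _prepare_apply_chat_template_tools_and_messages,
)

messages = [{"role": "user", "content": "Hello"}]
out_messages, out_tools = _prepare_apply_chat_template_tools_and_messages(
    messages, tools=[], add_generation_prompt=True
)
assert out_messages == [{"role": "user", "content": "Hello"}]

# Contradictory flags, or a flag that does not match the role of the final
# message, are rejected before anything is tokenized:
try:
    _prepare_apply_chat_template_tools_and_messages(
        messages, tools=[], add_generation_prompt=True, continue_final_message=True
    )
except ValueError:
    pass  # expected: both flags cannot be True at once
```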
@@ -200,108 +136,113 @@ def make_mistral_chat_completion_request( if function.get("description") is None: function["description"] = "" - from mistral_common.protocol.instruct.request import ChatCompletionRequest + return messages, tools - return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] + +def validate_request_params(request: "ChatCompletionRequest"): + if request.chat_template is not None or request.chat_template_kwargs is not None: + raise ValueError("chat_template is not supported for Mistral tokenizers.") + + +def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int: + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + assert isinstance(tokenizer, Tekkenizer), type(tokenizer) + + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t + shift = tokenizer.num_special_tokens + try: + return shift + tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + t_str = t_bytes.decode("utf-8") + if t_str in tokenizer._special_tokens_reverse_vocab: + return tokenizer._special_tokens_reverse_vocab[t_str] + logger.warning( + "Failed to convert token %s to id, replacing with ", t_bytes + ) + return tokenizer.unk_id class MistralTokenizer(TokenizerBase): - def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: - self.mistral = tokenizer - self.instruct = tokenizer.instruct_tokenizer - _mistral_version_str = self.instruct.tokenizer.version.value - self.version: int = int(_mistral_version_str.split("v")[-1]) - - tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy - from mistral_common.tokens.tokenizers.tekken import Tekkenizer - - self.is_tekken = isinstance(tokenizer_, Tekkenizer) + def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer - self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - self._special_token_policy = ( - SpecialTokenPolicy.IGNORE if self.is_tekken else None - ) + self.transformers_tokenizer = tokenizer + self.mistral = tokenizer.tokenizer + self.instruct = self.mistral.instruct_tokenizer + self.tokenizer = self.instruct.tokenizer + + _mistral_version_str = str(self.tokenizer.version.value) + self.version: int = int(_mistral_version_str.split("v")[-1]) + + self.is_tekken = isinstance(self.tokenizer, Tekkenizer) + self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer) if not (self.is_tekken or self.is_spm): - raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}") - self._vocab = tokenizer_.vocab() - # Convert to a dict[str, int] to match protocol, but this is a lossy - # conversion. There may be multiple token ids that decode to the same - # string due to partial UTF-8 byte sequences being converted to � - self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} - self.tokenizer = tokenizer_ + # Reverse order to ensure that the lowest token id is kept. + self._vocab_dict = { + self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i + for i in range(self.vocab_size - 1, -1, -1) + } + # Sort the dict for convenience + self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) + + # Vocab sorted by token id. 
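A toy illustration (not part of the patch) of why `_vocab_dict` above is built by iterating token ids in reverse: several ids can decode to the same string (partial UTF-8 byte tokens all render as "�"), and later dict insertions win, so walking from the highest id down leaves the lowest id in the mapping.

```python
decoded = {0: "a", 1: "�", 2: "�", 3: "b"}  # hypothetical id -> decoded-string table

vocab = {decoded[i]: i for i in range(len(decoded) - 1, -1, -1)}
assert vocab == {"a": 0, "�": 1, "b": 3}  # id 1 wins over id 2 for "�"

# Then sorted by id for convenience, as in the constructor above.
vocab = dict(sorted(vocab.items(), key=lambda x: x[1]))
assert list(vocab.values()) == [0, 1, 3]
```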
+ self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 @classmethod def from_pretrained( cls, path_or_repo_id: str, *, revision: Optional[str] = None ) -> "MistralTokenizer": - if not Path(path_or_repo_id).exists(): - assert len(path_or_repo_id.split("/")) == 2, ( - "You have either provided a non-existent path: " - "{path_or_repo_id} or an invalid HF Hub repo id." - ) - tokenizer_file = cls._download_mistral_tokenizer_from_hf( - path_or_repo_id, revision - ) - elif Path(path_or_repo_id).is_dir(): - tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id)) - tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) - else: - assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" - tokenizer_file = str(Path(path_or_repo_id)) - - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer, + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, ) - mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) - return cls(mistral_tokenizer) - - @staticmethod - def _download_mistral_tokenizer_from_hf( - tokenizer_name: str, revision: Optional[str] - ) -> str: - try: - hf_api = HfApi() - files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision) - except ConnectionError as exc: - files = list_local_repo_files(repo_id=tokenizer_name, revision=revision) - - if len(files) == 0: - raise exc - - filename = find_tokenizer_file(files) - - tokenizer_file = hf_hub_download( - tokenizer_name, filename=filename, revision=revision + str_revision = "main" if revision is None else revision + return cls( + TransformersMistralTokenizer.from_pretrained( + path_or_repo_id, revision=str_revision + ) ) - return tokenizer_file # the following attributes are set to fit vLLM's design and are used # by the structured output backends. 
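A short sketch (not part of the patch) of the refactored construction path and the special-token properties, using the model name and expectations from the tests added earlier in this diff:

```python
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")

assert tok.is_tekken
assert "[INST]" in tok.all_special_tokens
assert "[THINK]" in tok.all_special_tokens
# all_special_ids is returned sorted, and every special token string is
# recovered by decoding its id with SpecialTokenPolicy.KEEP.
assert tok.all_special_ids == sorted(tok.all_special_ids)
```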
@property def all_special_tokens_extended(self) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - # tekken defines its own extended special tokens list - if hasattr(self.tokenizer, "SPECIAL_TOKENS"): - special_tokens = self.tokenizer.SPECIAL_TOKENS - else: - special_tokens = list(SpecialTokens) - return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens] + return self.all_special_tokens @property def all_special_tokens(self) -> list[str]: - return self.all_special_tokens_extended + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + return [ + self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) + for i in self.all_special_ids + ] @property def all_special_ids(self) -> list[int]: - return [self.all_special_tokens.index(t) for t in self.all_special_tokens] + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens} + elif self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + special_ids = self.tokenizer._control_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + return sorted(special_ids) @property def bos_token_id(self) -> int: @@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase): @property def pad_token(self) -> str: - raise NotImplementedError() + return self.transformers_tokenizer.pad_token @property def is_fast(self) -> bool: @@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase): @property def vocab_size(self) -> int: - return len(self._vocab) + return self.transformers_tokenizer.vocab_size @property def max_token_id(self) -> int: @@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase): def truncation_side(self) -> str: raise NotImplementedError() + def _is_special_token_id(self, token_id: int) -> bool: + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + return token_id in self.tokenizer._control_tokens + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + return token_id < self.tokenizer.num_special_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + def __len__(self) -> int: return self.vocab_size @@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase): truncation: bool = False, max_length: Optional[int] = None, ): - input_ids: Union[list[int], list[list[int]]] - # For list[str], original prompt text - if is_list_of(text, str): - input_ids_: list[list[int]] = [] - for p in text: - each_input_ids = self.encode_one(p, truncation, max_length) - input_ids_.append(each_input_ids) - input_ids = input_ids_ - # For list[int], apply chat template output, already tokens. 
-        elif is_list_of(text, int):
-            input_ids = text
-        # For str, single prompt text
-        else:
-            input_ids = self.encode_one(text, truncation, max_length)
-        return BatchEncoding({"input_ids": input_ids})
+        return self.transformers_tokenizer(
+            text=text,
+            text_pair=text_pair,
+            add_special_tokens=add_special_tokens,
+            truncation=truncation,
+            max_length=max_length,
+        )
+
+    @property
+    def vocab(self) -> list[str]:
+        return self._vocab
 
     def get_vocab(self) -> dict[str, int]:
-        # NB: the dictionary form of the vocabulary collapses token ids that map
-        # to the same string but have different bytes
         return self._vocab_dict
 
     def get_added_vocab(self) -> dict[str, int]:
@@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase):
         max_length: Optional[int] = None,
     ) -> list[int]:
         # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(text)
-
-        if truncation:
-            input_ids = input_ids[:max_length]
-        return input_ids
+        return self.transformers_tokenizer.encode(
+            text, add_special_tokens=False, truncation=truncation, max_length=max_length
+        )
 
     def encode(
         self,
@@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase):
         max_length: Optional[int] = None,
         add_special_tokens: Optional[bool] = None,
     ) -> list[int]:
-        # `encode` should only be used for prompt completion
-        # it should never be used for chat_completion.
-        # For chat completion use `apply_chat_template`
         if add_special_tokens is not None:
-            return self.tokenizer.encode(
-                text, bos=add_special_tokens, eos=add_special_tokens
+            return self.transformers_tokenizer.encode(
+                text,
+                truncation=truncation,
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
             )
         else:
-            return self.tokenizer.encode(text, bos=True, eos=False)
+            encoded = self.tokenizer.encode(text, bos=True, eos=False)
+
+            if truncation is not False and max_length is not None:
+                return encoded[:max_length]
+            else:
+                return encoded
 
     def apply_chat_template(
         self,
@@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase):
         tools: Optional[list[dict[str, Any]]] = None,
         **kwargs,
     ) -> list[int]:
-        request = make_mistral_chat_completion_request(messages, tools)
-        encoded = self.mistral.encode_chat_completion(request)
+        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
+        continue_final_message = kwargs.get("continue_final_message", False)
+        padding = kwargs.get("padding", False)
+        truncation = kwargs.get("truncation", False)
+        max_length = kwargs.get("max_length")
 
-        # encode-decode to get clean prompt
-        return encoded.tokens
+        messages, tools = _prepare_apply_chat_template_tools_and_messages(
+            messages, tools, continue_final_message, add_generation_prompt
+        )
+
+        return self.transformers_tokenizer.apply_chat_template(
+            conversation=messages,
+            tools=tools,
+            continue_final_message=continue_final_message,
+            tokenize=True,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=None,
+            return_dict=False,
+        )
+
+    def decode(
+        self, ids: Union[list[int], int], skip_special_tokens: bool = True
+    ) -> str:
+        return self.transformers_tokenizer.decode(
+            ids, skip_special_tokens=skip_special_tokens
+        )
 
     def convert_tokens_to_string(self, tokens: list[str]) -> str:
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
+        from mistral_common.tokens.tokenizers.base import (
+            SpecialTokenPolicy,
+            SpecialTokens,
+        )
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer,
+        )
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
+        to_decode_special_tokens = {SpecialTokens.tool_calls}
 
         if self.is_tekken:
+            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
             tokens = [
                 t
                 for t in tokens
-                if (
-                    t is SpecialTokens.tool_calls
-                    or t not in self.tokenizer._all_special_tokens
-                )
+                if (t in to_decode_special_tokens or t not in self.all_special_tokens)
             ]
 
             if any(isinstance(t, bytes) for t in tokens):
                 # we need to encode and decode all tokens again
-                shift = self.tokenizer.num_special_tokens
-
-                def _token_to_id(t: str):
-                    t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
-                    try:
-                        return (
-                            shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
-                        )
-                    except KeyError:
-                        logger.warning(
-                            "Failed to convert token %s to id, replacing with ",
-                            t_bytes,
-                        )
-                        return self.tokenizer.unk_id
-
-                ids = [_token_to_id(t) for t in tokens]
-                decoded = self.tokenizer.decode(ids, self._special_token_policy)
+                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
+                # We filtered unwanted special tokens before
+                # so we can decode the rest.
+                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                 decoded = "".join(tokens)
         else:
             # make sure certain special tokens like Tool calls are
             # not decoded
-            special_tokens = {SpecialTokens.tool_calls}
+            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
+                self.tokenizer
+            )
+
             regular_tokens: list[str] = []
-            decoded_list = []
+            decoded_list: list[str] = []
+            decoded = ""
             for token in tokens:
-                if token in special_tokens:
+                if token in to_decode_special_tokens:
                     if regular_tokens:
                         decoded_list.append(
                             self.tokenizer.decode(
-                                regular_tokens, self._special_token_policy
+                                regular_tokens, SpecialTokenPolicy.IGNORE
                             )
                         )
                         regular_tokens = []
@@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase):
 
             if regular_tokens:
                 decoded_list.append(
-                    self.tokenizer.decode(regular_tokens, self._special_token_policy)
+                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                 )
-
             decoded = "".join(decoded_list)
 
         return decoded
 
-    def decode(
-        self, ids: Union[list[int], int], skip_special_tokens: bool = True
-    ) -> str:
-        assert skip_special_tokens, (
-            "skip_special_tokens=False is not supported for Mistral tokenizers."
-        )
-
-        if isinstance(ids, int):
-            ids = [ids]
-        return self.tokenizer.decode(ids, self._special_token_policy)
-
     def convert_ids_to_tokens(
         self,
         ids: list[int],
         skip_special_tokens: bool = True,
     ) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
+        from mistral_common.tokens.tokenizers.base import (
+            SpecialTokenPolicy,
+            SpecialTokens,
+        )
         from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
 
-        # TODO(Patrick) - potentially allow special tokens to not be skipped
-        assert skip_special_tokens, (
-            "skip_special_tokens=False is not supported for Mistral tokenizers."
- ) + if not skip_special_tokens: + return [self.tokenizer.id_to_piece(token_id) for token_id in ids] - assert self.is_tekken or self.is_spm, type(self.tokenizer) + non_skip_special_tokens_ids = { + self.tokenizer.get_control_token(SpecialTokens.tool_calls), + } + if isinstance(self.instruct, InstructTokenizerV13): + if self.instruct.BEGIN_THINK: + non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK) + if self.instruct.END_THINK: + non_skip_special_tokens_ids.add(self.instruct.END_THINK) - if self.is_tekken: - # skip special tokens except tool call and think tokens - non_skip_special_tokens = { - self.tokenizer.get_control_token(SpecialTokens.tool_calls) - } - if isinstance(self.instruct, InstructTokenizerV13): - if self.instruct.BEGIN_THINK: - non_skip_special_tokens.add(self.instruct.BEGIN_THINK) - if self.instruct.END_THINK: - non_skip_special_tokens.add(self.instruct.END_THINK) - ids = [ - i - for i in ids - if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens - ] + ids_kept = [ + i + for i in ids + if i in non_skip_special_tokens_ids or not self._is_special_token_id(i) + ] - tokens = [self.tokenizer.id_to_piece(id) for id in ids] + # We filtered unwanted special tokens so we can decode the rest. + tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept] if any("�" in t for t in tokens) and self.is_tekken: # if a decoded token contains the replacement character, then the # token has an incomplete UTF-8 character so we must use bytes # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 - # if underlying tokenizeir is sentencepiece, we just add "�" + # if underlying tokenizer is sentencepiece, we just add "�". + # We filtered unwanted special tokens so we can decode the rest. tokens = [ - self.tokenizer.id_to_byte_piece(id, self._special_token_policy) - for id in ids + self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) + if token_id not in self.all_special_ids + else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) + for token_id in ids_kept ] return tokens diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 9f81d09633d7b..4b21b2591c589 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend): if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 - try: - if self.tokenizer.is_tekken: - encoded_vocab = self.tokenizer._vocab - else: - encoded_vocab = [ - token - for token, _ in sorted( - self.tokenizer.get_vocab().items(), - key=lambda x: x[1], - ) - ] - stop_token_ids = None - if ( - hasattr( - self.tokenizer, - "eos_token_id", - ) - and self.tokenizer.eos_token_id is not None - ): - stop_token_ids = [self.tokenizer.eos_token_id] - except AttributeError as e: - raise ValueError( - f"Cannot get the vocabulary of the tokenizer " - f"{type(self.tokenizer)}. The tokenizer should have a " - "get_vocab method." - ) from e + stop_token_ids = [self.tokenizer.eos_token_id] + + # not self.tokenizer.vocab_size as self.tokenizer.vocab + # collapses all decoded errors into a single token. 
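The byte-piece fallback in `convert_ids_to_tokens` above and the vocabulary note here stem from the same property of byte-level (tekken) tokenization: one multi-byte UTF-8 character can be split across token ids, each fragment decodes on its own to the replacement character "�", and many distinct ids therefore share one decoded string. A small plain-Python illustration, independent of any tokenizer:

    robot = "🤖".encode("utf-8")                   # 4 bytes that a tokenizer may split
    head, tail = robot[:2], robot[2:]
    print(head.decode("utf-8", errors="replace"))  # only replacement character(s)
    print(tail.decode("utf-8", errors="replace"))  # only replacement character(s)
    print((head + tail).decode("utf-8"))           # '🤖' once the bytes are rejoined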
+ self.vocab_size = len(self.tokenizer.vocab) tokenizer_info = xgr.TokenizerInfo( # type: ignore - encoded_vocab=encoded_vocab, + encoded_vocab=self.tokenizer.vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW if self.tokenizer.is_tekken