Refactor MistralTokenizer (#26358)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize 2025-10-10 00:48:58 +02:00 committed by GitHub
parent 8983e0216f
commit c6187f55f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2351 additions and 463 deletions

View File

@ -145,7 +145,7 @@ Supported models:
Known issues: Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly. 1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is 2. **For the Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided: is not met, the following additional chat templates are provided:
@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt * <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling. when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common), the official Mistral tokenization backend:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
2. To use the default Transformers tokenization backend:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`) ### Llama Models (`llama3_json`)

View File

@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):
# Voxtral # Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData: def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import (
AudioChunk, AudioChunk,
RawAudio, RawAudio,
TextChunk, TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage, UserMessage,
) )
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest

View File

@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.13.0 gguf >= 0.13.0
importlib_metadata; python_version < '3.10' importlib_metadata; python_version < '3.10'
mistral_common[image,audio] >= 1.8.2 mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO opencv-python-headless >= 4.11.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

View File

@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test

View File

@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0 torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test

View File

@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.8.2 mistral-common==1.8.5
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0 mlflow==2.22.0
# via terratorch # via terratorch
@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# mteb # mteb
sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3 setuptools==77.0.3
# via # via
# lightning-utilities # lightning-utilities

View File

@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional from typing import Literal, Optional
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
}, },
{"role": "user", "content": "Thanks, what is 3+3?"}, {"role": "user", "content": "Thanks, what is 3+3?"},
] ]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
) )
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template( tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None mistral_tokenizer, messages, chat_template=None, tools=None

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
import pytest import pytest
from mistral_common.multimodal import download_image from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

View File

@ -6,12 +6,8 @@ import json
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer

View File

@ -6,7 +6,8 @@ from typing import Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image

View File

@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np import numpy as np
import pytest import pytest
import torch.nn as nn import torch.nn as nn
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image

View File

@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from tests.reasoning.utils import run_reasoning_extraction_mistral from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
@ -14,33 +12,9 @@ parser_name = "mistral"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mistral_tokenizer(): def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
) )
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer return mistral_tokenizer

File diff suppressed because it is too large Load Diff

View File

@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
chat_template: Optional[str], chat_template: Optional[str],
**kwargs: Any, **kwargs: Any,
) -> Optional[str]: ) -> Optional[str]:
if chat_template is not None: if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
logger.warning_once( raise ValueError(
"'chat_template' cannot be overridden for mistral tokenizer." "'chat_template' or 'chat_template_kwargs' cannot be overridden "
) "for mistral tokenizer."
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
) )
return None return None

View File

@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image

View File

@ -12,12 +12,8 @@ import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from mistral_common.audio import mel_filter_bank from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder

View File

@ -1,34 +1,27 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub
import regex as re
from huggingface_hub import HfApi, hf_hub_download
from transformers.tokenization_utils_base import BatchEncoding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase from vllm.transformers_utils.tokenizer_base import TokenizerBase
from vllm.utils import is_list_of
if TYPE_CHECKING: if TYPE_CHECKING:
# make sure `mistral_common` is lazy imported, from mistral_common.protocol.instruct.request import (
# so that users who only use non-mistral models ChatCompletionRequest as MistralChatCompletionRequest,
# will not be bothered by the dependency. )
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.mistral import ( from transformers.tokenization_mistral_common import (
MistralTokenizer as PublicMistralTokenizer, MistralCommonTokenizer as TransformersMistralTokenizer,
) )
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951 # SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes # Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes # NOTE: There is currently a bug in pydantic where attributes
@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
request.messages[i]["tool_calls"] = validated_tool_calls request.messages[i]["tool_calls"] = validated_tool_calls
def truncate_tool_call_ids(request: "ChatCompletionRequest"): def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
"""Truncates tool call IDs for Mistral's ID requirements.""" """Truncates tool call IDs for Mistral's ID requirements."""
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if message.get("role") == "assistant": if message.get("role") == "assistant":
@ -95,84 +88,34 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"): def _prepare_apply_chat_template_tools_and_messages(
if request.skip_special_tokens is not None and not request.skip_special_tokens:
raise ValueError(
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *repo_id.split("/")]
),
)
if revision is None:
revision_file = os.path.join(repo_cache, "refs", "main")
if os.path.isfile(revision_file):
with open(revision_file) as file:
revision = file.read()
if revision:
revision_dir = os.path.join(repo_cache, "snapshots", revision)
if os.path.isdir(revision_dir):
return os.listdir(revision_dir)
return []
def find_tokenizer_file(files: list[str]):
# Accept both versioned (tokenizer.model.v3) and unversioned
# (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model
# variants. Previous pattern only matched the versioned variants.
file_pattern = re.compile(
r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$"
)
matched_files = [file for file in files if file_pattern.match(file)]
if len(matched_files) > 1:
logger.warning(
"Multiple files matched pattern `%s`: %s. Using %s.",
file_pattern.pattern,
matched_files,
matched_files[0],
)
elif len(matched_files) == 0:
raise OSError(
f"Found {len(matched_files)} files matching the "
f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral "
f"tokenizer is present in {files}."
)
return matched_files[0]
def _aggregate_content(content: list) -> list[dict[str, Any]]:
aggregated_content: list[dict[str, Any]] = []
for chunk in content:
if (
chunk.get("type") == "text"
and aggregated_content
and aggregated_content[-1].get("type") == "text"
):
aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
else:
aggregated_content.append(chunk)
if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text":
content = aggregated_content[0]["text"]
return content
def make_mistral_chat_completion_request(
messages: list["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
) -> "ChatCompletionRequest": continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]:
if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
"`continue_final_message` to True."
)
last_message = cast(dict[str, Any], messages[-1]) last_message = cast(dict[str, Any], messages[-1])
if last_message["role"] == "assistant": # add_generation_prompt is directly handled by the tokenizer but we
last_message["prefix"] = True # check if the user is trying to use it with a final assistant message
# which is probably not what they want.
# If add_generation_prompt is False, we don't need to check anything.
if add_generation_prompt and last_message["role"] == "assistant":
raise ValueError(
"Cannot set `add_generation_prompt` to True when "
"the last message is from the assistant. Consider "
"using `continue_final_message` instead."
)
if continue_final_message and last_message["role"] != "assistant":
raise ValueError(
"Cannot set `continue_final_message` to True when "
"the last message is not from the assistant."
)
# mistral-common requires AssistantMessage content to be string [1]. # mistral-common requires AssistantMessage content to be string [1].
# #
@ -181,13 +124,6 @@ def make_mistral_chat_completion_request(
# Remove reasoning_content as unsupported by Mistral # Remove reasoning_content as unsupported by Mistral
_ = message.pop("reasoning_content", None) # type: ignore _ = message.pop("reasoning_content", None) # type: ignore
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content: Any = message.get("content")
if isinstance(content, list):
content = _aggregate_content(content)
message["content"] = content
# The Mistral client, in comparison to the OpenAI client, requires the # The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present # "parameters" dict and the "description" string to be present
# even if they are empty. # even if they are empty.
@ -200,108 +136,113 @@ def make_mistral_chat_completion_request(
if function.get("description") is None: if function.get("description") is None:
function["description"] = "" function["description"] = ""
from mistral_common.protocol.instruct.request import ChatCompletionRequest return messages, tools
return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var]
def validate_request_params(request: "ChatCompletionRequest"):
if request.chat_template is not None or request.chat_template_kwargs is not None:
raise ValueError("chat_template is not supported for Mistral tokenizers.")
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
shift = tokenizer.num_special_tokens
try:
return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
t_str = t_bytes.decode("utf-8")
if t_str in tokenizer._special_tokens_reverse_vocab:
return tokenizer._special_tokens_reverse_vocab[t_str]
logger.warning(
"Failed to convert token %s to id, replacing with <unk>", t_bytes
)
return tokenizer.unk_id
class MistralTokenizer(TokenizerBase): class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
_mistral_version_str = self.instruct.tokenizer.version.value
self.version: int = int(_mistral_version_str.split("v")[-1])
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
from mistral_common.tokens.tokenizers.sentencepiece import ( from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer, SentencePieceTokenizer,
) )
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) self.transformers_tokenizer = tokenizer
self._special_token_policy = ( self.mistral = tokenizer.tokenizer
SpecialTokenPolicy.IGNORE if self.is_tekken else None self.instruct = self.mistral.instruct_tokenizer
) self.tokenizer = self.instruct.tokenizer
_mistral_version_str = str(self.tokenizer.version.value)
self.version: int = int(_mistral_version_str.split("v")[-1])
self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
if not (self.is_tekken or self.is_spm): if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")
self._vocab = tokenizer_.vocab() # Reverse order to ensure that the lowest token id is kept.
# Convert to a dict[str, int] to match protocol, but this is a lossy self._vocab_dict = {
# conversion. There may be multiple token ids that decode to the same self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
# string due to partial UTF-8 byte sequences being converted to <20> for i in range(self.vocab_size - 1, -1, -1)
self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} }
self.tokenizer = tokenizer_ # Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1 self._max_token_id = self.vocab_size - 1
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, path_or_repo_id: str, *, revision: Optional[str] = None cls, path_or_repo_id: str, *, revision: Optional[str] = None
) -> "MistralTokenizer": ) -> "MistralTokenizer":
if not Path(path_or_repo_id).exists(): from transformers.tokenization_mistral_common import (
assert len(path_or_repo_id.split("/")) == 2, ( MistralCommonTokenizer as TransformersMistralTokenizer,
"You have either provided a non-existent path: "
"{path_or_repo_id} or an invalid HF Hub repo id."
)
tokenizer_file = cls._download_mistral_tokenizer_from_hf(
path_or_repo_id, revision
)
elif Path(path_or_repo_id).is_dir():
tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id))
tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
else:
assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
tokenizer_file = str(Path(path_or_repo_id))
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer,
) )
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) str_revision = "main" if revision is None else revision
return cls(mistral_tokenizer) return cls(
TransformersMistralTokenizer.from_pretrained(
@staticmethod path_or_repo_id, revision=str_revision
def _download_mistral_tokenizer_from_hf( )
tokenizer_name: str, revision: Optional[str]
) -> str:
try:
hf_api = HfApi()
files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision)
except ConnectionError as exc:
files = list_local_repo_files(repo_id=tokenizer_name, revision=revision)
if len(files) == 0:
raise exc
filename = find_tokenizer_file(files)
tokenizer_file = hf_hub_download(
tokenizer_name, filename=filename, revision=revision
) )
return tokenizer_file
# the following attributes are set to fit vLLM's design and are used # the following attributes are set to fit vLLM's design and are used
# by the structured output backends. # by the structured output backends.
@property @property
def all_special_tokens_extended(self) -> list[str]: def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens return self.all_special_tokens
# tekken defines its own extended special tokens list
if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
special_tokens = self.tokenizer.SPECIAL_TOKENS
else:
special_tokens = list(SpecialTokens)
return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens]
@property @property
def all_special_tokens(self) -> list[str]: def all_special_tokens(self) -> list[str]:
return self.all_special_tokens_extended from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in self.all_special_ids
]
@property @property
def all_special_ids(self) -> list[int]: def all_special_ids(self) -> list[int]:
return [self.all_special_tokens.index(t) for t in self.all_special_tokens] from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
@property @property
def bos_token_id(self) -> int: def bos_token_id(self) -> int:
@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase):
@property @property
def pad_token(self) -> str: def pad_token(self) -> str:
raise NotImplementedError() return self.transformers_tokenizer.pad_token
@property @property
def is_fast(self) -> bool: def is_fast(self) -> bool:
@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase):
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return len(self._vocab) return self.transformers_tokenizer.vocab_size
@property @property
def max_token_id(self) -> int: def max_token_id(self) -> int:
@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase):
def truncation_side(self) -> str: def truncation_side(self) -> str:
raise NotImplementedError() raise NotImplementedError()
def _is_special_token_id(self, token_id: int) -> bool:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
return token_id in self.tokenizer._control_tokens
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
return token_id < self.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase):
truncation: bool = False, truncation: bool = False,
max_length: Optional[int] = None, max_length: Optional[int] = None,
): ):
input_ids: Union[list[int], list[list[int]]] return self.transformers_tokenizer(
# For list[str], original prompt text text=text,
if is_list_of(text, str): text_pair=text_pair,
input_ids_: list[list[int]] = [] add_special_tokens=add_special_tokens,
for p in text: truncation=truncation,
each_input_ids = self.encode_one(p, truncation, max_length) max_length=max_length,
input_ids_.append(each_input_ids) )
input_ids = input_ids_
# For list[int], apply chat template output, already tokens. @property
elif is_list_of(text, int): def vocab(self) -> list[str]:
input_ids = text return self._vocab
# For str, single prompt text
else:
input_ids = self.encode_one(text, truncation, max_length)
return BatchEncoding({"input_ids": input_ids})
def get_vocab(self) -> dict[str, int]: def get_vocab(self) -> dict[str, int]:
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return self._vocab_dict return self._vocab_dict
def get_added_vocab(self) -> dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None, max_length: Optional[int] = None,
) -> list[int]: ) -> list[int]:
# Mistral Tokenizers should not add special tokens # Mistral Tokenizers should not add special tokens
input_ids = self.encode(text) return self.transformers_tokenizer.encode(
text, add_special_tokens=False, truncation=truncation, max_length=max_length
if truncation: )
input_ids = input_ids[:max_length]
return input_ids
def encode( def encode(
self, self,
@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None, max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None, add_special_tokens: Optional[bool] = None,
) -> list[int]: ) -> list[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
if add_special_tokens is not None: if add_special_tokens is not None:
return self.tokenizer.encode( return self.transformers_tokenizer.encode(
text, bos=add_special_tokens, eos=add_special_tokens text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
) )
else: else:
return self.tokenizer.encode(text, bos=True, eos=False) encoded = self.tokenizer.encode(text, bos=True, eos=False)
if truncation is not False and max_length is not None:
return encoded[:max_length]
else:
return encoded
def apply_chat_template( def apply_chat_template(
self, self,
@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase):
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
**kwargs, **kwargs,
) -> list[int]: ) -> list[int]:
request = make_mistral_chat_completion_request(messages, tools) add_generation_prompt = kwargs.pop("add_generation_prompt", False)
encoded = self.mistral.encode_chat_completion(request) continue_final_message = kwargs.get("continue_final_message", False)
padding = kwargs.get("padding", False)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
# encode-decode to get clean prompt messages, tools = _prepare_apply_chat_template_tools_and_messages(
return encoded.tokens messages, tools, continue_final_message, add_generation_prompt
)
return self.transformers_tokenizer.apply_chat_template(
conversation=messages,
tools=tools,
continue_final_message=continue_final_message,
tokenize=True,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=None,
return_dict=False,
)
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
return self.transformers_tokenizer.decode(
ids, skip_special_tokens=skip_special_tokens
)
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken: if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [ tokens = [
t t
for t in tokens for t in tokens
if ( if (t in to_decode_special_tokens or t not in self.all_special_tokens)
t is SpecialTokens.tool_calls
or t not in self.tokenizer._all_special_tokens
)
] ]
if any(isinstance(t, bytes) for t in tokens): if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again # we need to encode and decode all tokens again
shift = self.tokenizer.num_special_tokens ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
# We filtered unwanted special tokens before
def _token_to_id(t: str): # so we can decode the rest.
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
try:
return (
shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
)
except KeyError:
logger.warning(
"Failed to convert token %s to id, replacing with <unk>",
t_bytes,
)
return self.tokenizer.unk_id
ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids, self._special_token_policy)
else: else:
decoded = "".join(tokens) decoded = "".join(tokens)
else: else:
# make sure certain special tokens like Tool calls are # make sure certain special tokens like Tool calls are
# not decoded # not decoded
special_tokens = {SpecialTokens.tool_calls} assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
regular_tokens: list[str] = [] regular_tokens: list[str] = []
decoded_list = [] decoded_list: list[str] = []
decoded = ""
for token in tokens: for token in tokens:
if token in special_tokens: if token in to_decode_special_tokens:
if regular_tokens: if regular_tokens:
decoded_list.append( decoded_list.append(
self.tokenizer.decode( self.tokenizer.decode(
regular_tokens, self._special_token_policy regular_tokens, SpecialTokenPolicy.IGNORE
) )
) )
regular_tokens = [] regular_tokens = []
@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase):
if regular_tokens: if regular_tokens:
decoded_list.append( decoded_list.append(
self.tokenizer.decode(regular_tokens, self._special_token_policy) self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
) )
decoded = "".join(decoded_list) decoded = "".join(decoded_list)
return decoded return decoded
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
assert skip_special_tokens, (
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids, self._special_token_policy)
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, self,
ids: list[int], ids: list[int],
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
) -> list[str]: ) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
# TODO(Patrick) - potentially allow special tokens to not be skipped if not skip_special_tokens:
assert skip_special_tokens, ( return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
assert self.is_tekken or self.is_spm, type(self.tokenizer) non_skip_special_tokens_ids = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
}
if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK:
non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK:
non_skip_special_tokens_ids.add(self.instruct.END_THINK)
if self.is_tekken: ids_kept = [
# skip special tokens except tool call and think tokens i
non_skip_special_tokens = { for i in ids
self.tokenizer.get_control_token(SpecialTokens.tool_calls) if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
} ]
if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK:
non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK:
non_skip_special_tokens.add(self.instruct.END_THINK)
ids = [
i
for i in ids
if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens
]
tokens = [self.tokenizer.id_to_piece(id) for id in ids] # We filtered unwanted special tokens so we can decode the rest.
tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]
if any("<EFBFBD>" in t for t in tokens) and self.is_tekken: if any("<EFBFBD>" in t for t in tokens) and self.is_tekken:
# if a decoded token contains the replacement character, then the # if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes # token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640 # See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625 # https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizeir is sentencepiece, we just add "<22>" # if underlying tokenizer is sentencepiece, we just add "<22>".
# We filtered unwanted special tokens so we can decode the rest.
tokens = [ tokens = [
self.tokenizer.id_to_byte_piece(id, self._special_token_policy) self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
for id in ids if token_id not in self.all_special_ids
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept
] ]
return tokens return tokens

View File

@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
if isinstance(self.tokenizer, MistralTokenizer): if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try: stop_token_ids = [self.tokenizer.eos_token_id]
if self.tokenizer.is_tekken:
encoded_vocab = self.tokenizer._vocab # not self.tokenizer.vocab_size as self.tokenizer.vocab
else: # collapses all decoded errors into a single token.
encoded_vocab = [ self.vocab_size = len(self.tokenizer.vocab)
token
for token, _ in sorted(
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if (
hasattr(
self.tokenizer,
"eos_token_id",
)
and self.tokenizer.eos_token_id is not None
):
stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(self.tokenizer)}. The tokenizer should have a "
"get_vocab method."
) from e
tokenizer_info = xgr.TokenizerInfo( # type: ignore tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab, encoded_vocab=self.tokenizer.vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.RAW vocab_type=xgr.VocabType.RAW
if self.tokenizer.is_tekken if self.tokenizer.is_tekken