diff --git a/requirements/common.txt b/requirements/common.txt index 1876a7e9af08..96ab646bb50a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -33,7 +33,7 @@ pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata; python_version < '3.10' -mistral_common[opencv] >= 1.8.0 +mistral_common[image,audio] >= 1.8.2 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 9c378dcf68fb..0a72ddefda79 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.8.0 # required for voxtral test +mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test diff --git a/requirements/test.in b/requirements/test.in index 9f66e2d6919a..429d1a50422f 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -28,7 +28,7 @@ torchvision==0.22.1 transformers_stream_generator # required for qwen-vl test mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.8.0 # required for voxtral test +mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test diff --git a/requirements/test.txt b/requirements/test.txt index a2b230102d4e..8e5af8d74bad 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -447,7 +447,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.8.0 +mistral-common==1.8.2 # via -r requirements/test.in mlflow==2.22.0 # via terratorch @@ -999,8 +999,11 @@ soundfile==0.12.1 # via # -r requirements/test.in # librosa + # mistral-common soxr==0.5.0.post1 - # via librosa + # via + # librosa + # mistral-common sqlalchemy==2.0.41 # via # alembic diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index e321ca70001d..ed57fe39df64 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,6 +6,10 @@ from collections.abc import Mapping from typing import Literal, Optional import pytest +from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy, + SpecialTokens) +from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, + Tekkenizer) from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -21,6 +25,7 @@ from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, encode_video_base64) from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH @@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format): ) assert resolved_format == expected_format + + +def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, + mistral_tokenizer): + messages = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": + "thinking", + "closed": + True, + "thinking": + "Only return the answer when you are confident." + }] + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": + "assistant", + "content": [{ + "type": "text", + "text": "Let me think about it." + }, { + "type": "thinking", + "closed": True, + "thinking": "2+2 = 4" + }, { + "type": "text", + "text": "The answer is 4.", + }], + }] + + conversation_with_thinking, _ = parse_chat_messages( + messages, + mistral_model_config, + mistral_tokenizer, + content_format="openai", + ) + + expected_conversation = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": "text", + "text": "Only return the answer when you are confident." + }], + }, { + "role": + "user", + "content": [{ + "type": "text", + "text": "What is 2+2?" + }], + }, { + "role": + "assistant", + "content": [ + { + "type": "text", + "text": "Let me think about it." + }, + { + "type": "text", + "text": "2+2 = 4" + }, + { + "type": "text", + "text": "The answer is 4." + }, + ] + }] + + assert conversation_with_thinking == expected_conversation + + +def test_apply_mistral_chat_template_thinking_chunk(): + # Moved import here to avoid yapf and isort conflicts + from vllm.entrypoints.chat_utils import apply_mistral_chat_template + messages = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": + "thinking", + "closed": + True, + "thinking": + "Only return the answer when you are confident." + }] + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": + "assistant", + "content": [{ + "type": "text", + "text": "Let me think about it." + }, { + "type": "thinking", + "closed": True, + "thinking": "2+2 = 4" + }, { + "type": "text", + "text": "The answer is 4.", + }], + }, { + "role": "user", + "content": "Thanks, what is 3+3?" + }] + + # TODO(Julien): upon model release change to a tokenizer already configured. + # ================================================================= + mistral_tokenizer = MistralTokenizer.from_pretrained( + "mistralai/Devstral-Small-2507") + assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) + # Add think special tokens to the tokenizer + mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( + rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { + k: v + for k, v in + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + if v not in {35, 36} + } + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.begin_think.value] = 35 + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.end_think.value] = 36 + mistral_tokenizer.instruct.BEGIN_THINK = 35 + mistral_tokenizer.instruct.END_THINK = 36 + # ================================================================= + + tokens_ids = apply_mistral_chat_template(mistral_tokenizer, + messages, + chat_template=None, + tools=None) + + string_tokens = mistral_tokenizer.mistral.decode( + tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP) + + expected_tokens = ( + r"[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" + r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" + r"[INST]What is 2+2?[/INST]" + r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4." + r"[INST]Thanks, what is 3+3?[/INST]") + + assert string_tokens == expected_tokens diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py new file mode 100644 index 000000000000..91a22f6f5d72 --- /dev/null +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from mistral_common.tokens.tokenizers.base import SpecialTokens +from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, + Tekkenizer) + +from tests.reasoning.utils import run_reasoning_extraction_mistral +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +parser_name = "mistral" + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + # TODO(Julien): upon model release change to a tokenizer already configured. + # ================================================================= + mistral_tokenizer = MistralTokenizer.from_pretrained( + "mistralai/Devstral-Small-2507") + assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) + # Add think special tokens to the tokenizer + mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( + rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { + k: v + for k, v in + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + if v not in {35, 36} + } + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.begin_think.value] = 35 + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.end_think.value] = 36 + mistral_tokenizer.instruct.BEGIN_THINK = 35 + mistral_tokenizer.instruct.END_THINK = 36 + # ================================================================= + return mistral_tokenizer + + +SIMPLE_REASONING = { + "output": "This is a reasoning section[/THINK]This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING = { + "output": "This is a reasoning section[/THINK]", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +NO_CONTENT = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING_STREAMING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES = { + "output": "This\nThat[/THINK]This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING = { + "output": "[/THINK]This is the rest", + "reasoning_content": "", + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING = { + "output": "[/THINK]This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +REASONING_WITH_THINK = { + "output": "[THINK]This is a reasoning section[/THINK]This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING_WITH_THINK = { + "output": "[THINK]This is a reasoning section[/THINK]", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +MULTIPLE_LINES_WITH_THINK = { + "output": "[THINK]This\nThat[/THINK]This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { + "output": "[/THINK]This is the rest", + "reasoning_content": "", + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING_WITH_THINK = { + "output": "[/THINK]This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +THINK_NO_END = { + "output": "[THINK]This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +EMPTY = { + "output": "", + "reasoning_content": "", + "content": None, + "is_reasoning_end": False, +} +EMPTY_STREAMING = { + "output": "", + "reasoning_content": None, + "content": None, + "is_reasoning_end": False, +} +NEW_LINE = { + "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} +# Streaming cannot handle new lines at the beginning of the output +# because we need to support [THINK]...[/THINK] and [/THINK]... +# We cannot know if the text before [THINK] is reasoning content +# or not. +NEW_LINE_STREAMING = { + "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "reasoning_content": "\nThis is a reasoning section", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + False, + NO_CONTENT, + id="no_content_token", + ), + pytest.param( + True, + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + SHORTEST_REASONING, + id="shortest", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING, + id="shortest_streaming", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="shortest_with_think", + ), + pytest.param( + True, + SHORTEST_REASONING_WITH_THINK, + id="shortest_with_think_streaming", + ), + pytest.param( + False, + THINK_NO_END, + id="think_no_end", + ), + pytest.param( + True, + THINK_NO_END, + id="think_no_end_streaming", + ), + pytest.param( + False, + EMPTY, + id="empty", + ), + pytest.param( + True, + EMPTY_STREAMING, + id="empty_streaming", + ), + pytest.param( + False, + NEW_LINE, + id="new_line", + ), + pytest.param( + True, + NEW_LINE_STREAMING, + id="new_line_streaming", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_mistral_reasoning( + streaming: bool, + param_dict: dict, + mistral_tokenizer: MistralTokenizer, +): + output = param_dict["output"] + + index_think = output.find("[THINK]") + len_think = len("[THINK]") + index_end_think = output.find("[/THINK]") + len_end_think = len("[/THINK]") + + # encode everything to tokens ids + output_tokens = [] + if index_think != -1: + output_before_think = output[:index_think] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_before_think, False, False) + output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK] + + if index_end_think != -1: + output_middle = output[index_think + len_think:index_end_think] + output_after_think = output[index_end_think + len_end_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_middle, False, False) + output_tokens += [mistral_tokenizer.instruct.END_THINK] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_after_think, False, False) + else: + output_middle = output[index_think + len_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_middle, False, False) + elif index_end_think != -1: + output_before_think = output[:index_end_think] + output_after_think = output[index_end_think + len_end_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_before_think, False, False) + output_tokens += [mistral_tokenizer.instruct.END_THINK] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_after_think, False, False) + else: + output_tokens += mistral_tokenizer.tokenizer.encode( + output, False, False) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(mistral_tokenizer) + + reasoning, content = run_reasoning_extraction_mistral(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] + + # Test is_reasoning_end + is_reasoning_end = parser.is_reasoning_end(output_tokens) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + # Test extract_content + if param_dict["content"] is not None: + content = parser.extract_content_ids(output_tokens) + assert content == mistral_tokenizer.tokenizer.encode( + param_dict["content"], bos=False, eos=False) + else: + content = parser.extract_content_ids(output_tokens) + assert content == [] diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index ddcf89796fb5..9af5fa5addbc 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -6,6 +6,7 @@ from typing import Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) from vllm.reasoning import ReasoningParser +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: @@ -54,6 +55,32 @@ def run_reasoning_extraction( return reasoning, content +def run_reasoning_extraction_mistral( + reasoning_parser: ReasoningParser, + model_output: list[int], + request: Union[ChatCompletionRequest, None] = None, + streaming: bool = False, +) -> tuple[Optional[str], Optional[str]]: + assert isinstance(reasoning_parser.model_tokenizer, + MistralTokenizer), type(reasoning_parser.model_tokenizer) + if streaming: + reconstructor = run_reasoning_extraction_streaming_mistral( + reasoning_parser, + model_output, + request, + ) + return ( + reconstructor.reasoning_content, + reconstructor.other_content or None, + ) + else: + str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens( + model_output) + reasoning, content = run_reasoning_extraction_nonstreaming( + reasoning_parser, str_output, request) + return reasoning, content + + def run_reasoning_extraction_nonstreaming( reasoning_parser: ReasoningParser, model_output: list[str], @@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming( previous_text = current_text previous_tokens = current_tokens return reconstructor + + +def run_reasoning_extraction_streaming_mistral( + reasoning_parser: ReasoningParser, + model_deltas: list[int], + request: Union[ChatCompletionRequest, None] = None, +) -> StreamingReasoningReconstructor: + assert isinstance(reasoning_parser.model_tokenizer, + MistralTokenizer), type(reasoning_parser.model_tokenizer) + request = request or ChatCompletionRequest(messages=[], model="test-model") + reconstructor = StreamingReasoningReconstructor() + previous_text = "" + previous_tokens: list[int] = [] + for model_delta in model_deltas: + token_delta = [model_delta] + delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens( + [model_delta])[0] + current_text = previous_text + delta + current_tokens = previous_tokens + token_delta + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + ) + if delta_message is not None: + reconstructor.append_delta(delta_message) + previous_text = current_text + previous_tokens = current_tokens + return reconstructor diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 496caef4256d..a6602391d408 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -151,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): video_url: Required[str] +class CustomThinkCompletionContentParam(TypedDict, total=False): + """A Think Completion Content Param that accepts a plain text and a boolean. + + Example: + { + "thinking": "I am thinking about the answer", + "closed": True, + "type": "thinking" + } + """ + + thinking: Required[str] + """The thinking content.""" + + closed: bool + """Whether the thinking is closed.""" + + type: Required[Literal["thinking"]] + """The thinking type.""" + + ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, @@ -159,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[ CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str] + CustomChatCompletionContentSimpleVideoParam, str, + CustomThinkCompletionContentParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -938,6 +960,7 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam) +_ThinkParser = partial(cast, CustomThinkCompletionContentParam) # Need to validate url objects _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python @@ -954,6 +977,8 @@ MM_PARSER_MAP: dict[ ] = { "text": lambda part: _TextParser(part).get("text", None), + "thinking": + lambda part: _ThinkParser(part).get("thinking", None), "input_text": lambda part: _TextParser(part).get("text", None), "input_image": @@ -1100,7 +1125,7 @@ def _parse_chat_message_content_part( "with empty / unparsable content.", part, part_type) return None - if part_type in ("text", "input_text", "refusal"): + if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: return {'type': 'text', 'text': str_content} diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index bae593c1dff0..d61e4f11dfa2 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser +from .mistral_reasoning_parser import MistralReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser __all__ = [ @@ -16,4 +17,5 @@ __all__ = [ "HunyuanA13BReasoningParser", "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", + "MistralReasoningParser", ] diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py new file mode 100644 index 000000000000..6c707a4079fa --- /dev/null +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("mistral") +class MistralReasoningParser(DeepSeekR1ReasoningParser): + """ + Reasoning parser for Mistral models. + + The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning + text. This parser extracts the reasoning content from the model output. + """ + + def __init__(self, tokenizer: MistralTokenizer): + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "The tokenizer must be an instance of MistralTokenizer.") + + ReasoningParser.__init__(self, tokenizer) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + from mistral_common.tokens.tokenizers.base import SpecialTokens + + self.start_token = SpecialTokens.begin_think + self.end_token = SpecialTokens.end_think + + self.start_token_id = tokenizer.tokenizer.get_control_token( + self.start_token) + self.end_token_id = tokenizer.tokenizer.get_control_token( + self.end_token) + + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + "Mistral reasoning parser could not locate think start/end " + "tokens in the tokenizer!") diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 24ac4580d670..f83405cfc016 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]): return matched_files[0] +def _aggregate_content(content: list) -> list[dict[str, Any]]: + aggregated_content: list[dict[str, Any]] = [] + for chunk in content: + if chunk.get("type" + ) == "text" and aggregated_content and aggregated_content[ + -1].get("type") == "text": + aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") + else: + aggregated_content.append(chunk) + if len(aggregated_content) == 1 and aggregated_content[0].get( + "type") == "text": + content = aggregated_content[0]["text"] + return content + + def make_mistral_chat_completion_request( messages: list["ChatCompletionMessageParam"], tools: Optional[list[dict[str, @@ -162,10 +177,10 @@ def make_mistral_chat_completion_request( # Convert list text content to string if message.get("role") in ("assistant", "tool"): - content = message.get("content") + content: Any = message.get("content") if isinstance(content, list): - content = "\n".join(chunk.get("text") for chunk in content) - message["content"] = content + content = _aggregate_content(content) + message["content"] = content # The Mistral client, in comparison to the OpenAI client, requires the # "parameters" dict to be present, even if it's empty. @@ -465,6 +480,8 @@ class MistralTokenizer(TokenizerBase): skip_special_tokens: bool = True, ) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens + from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerV13) # TODO(Patrick) - potentially allow special tokens to not be skipped assert ( @@ -474,10 +491,18 @@ class MistralTokenizer(TokenizerBase): assert self.is_tekken or self.is_spm, type(self.tokenizer) if self.is_tekken: - # skip special tokens except tool call - ids = [ - i for i in ids if i > self.tokenizer.num_special_tokens or i == + # skip special tokens except tool call and think tokens + non_skip_special_tokens = { self.tokenizer.get_control_token(SpecialTokens.tool_calls) + } + if isinstance(self.instruct, InstructTokenizerV13): + if self.instruct.BEGIN_THINK: + non_skip_special_tokens.add(self.instruct.BEGIN_THINK) + if self.instruct.END_THINK: + non_skip_special_tokens.add(self.instruct.END_THINK) + ids = [ + i for i in ids if i > self.tokenizer.num_special_tokens + or i in non_skip_special_tokens ] tokens = [self.tokenizer.id_to_piece(id) for id in ids]