mirror of https://git.datalinker.icu/vllm-project/vllm.git
Add think chunk (#21333)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
parent 11ef7a611e
commit 6d8d0a24c0
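
In short, the change lets chat messages carry an explicit "thinking" content part next to ordinary text parts. A minimal sketch of such a payload, following the {"type": "thinking", "closed": ..., "thinking": ...} shape defined by the CustomThinkCompletionContentParam TypedDict added below (the wording and the model turn are illustrative, not taken from the commit):

# Illustrative only: an assistant turn carrying a "thinking" chunk between
# plain text parts, using the part shape introduced by this commit.
messages = [
    {"role": "user", "content": "What is 2+2?"},
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me think about it."},
            {"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
            {"type": "text", "text": "The answer is 4."},
        ],
    },
]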
requirements/common.txt
@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
 msgspec
 gguf >= 0.13.0
 importlib_metadata; python_version < '3.10'
-mistral_common[opencv] >= 1.8.0
+mistral_common[image,audio] >= 1.8.2
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
requirements/test.in
@@ -23,7 +23,7 @@ jiwer # required for audio tests
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.8.0 # required for voxtral test
+mistral_common[image,audio] >= 1.8.2 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
@@ -28,7 +28,7 @@ torchvision==0.22.1
 transformers_stream_generator # required for qwen-vl test
 mamba_ssm # required for plamo2 test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.8.0 # required for voxtral test
+mistral_common[image,audio] >= 1.8.2 # required for voxtral test
 num2words # required for smolvlm test
 open_clip_torch==2.32.0 # Required for nemotron_vl test
 opencv-python-headless >= 4.11.0 # required for video test
requirements/test.txt
@@ -447,7 +447,7 @@ mbstrdecoder==1.1.3
     #   typepy
 mdurl==0.1.2
     # via markdown-it-py
-mistral-common==1.8.0
+mistral-common==1.8.2
     # via -r requirements/test.in
 mlflow==2.22.0
     # via terratorch
@@ -999,8 +999,11 @@ soundfile==0.12.1
     # via
     #   -r requirements/test.in
     #   librosa
+    #   mistral-common
 soxr==0.5.0.post1
-    # via librosa
+    # via
+    #   librosa
+    #   mistral-common
 sqlalchemy==2.0.41
     # via
     #   alembic
tests/entrypoints/test_chat_utils.py
@@ -6,6 +6,10 @@ from collections.abc import Mapping
 from typing import Literal, Optional

 import pytest
+from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
+                                                   SpecialTokens)
+from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
+                                                     Tekkenizer)

 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
@@ -21,6 +25,7 @@ from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
                                    encode_video_base64)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import VLLM_PATH
@@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
     )

     assert resolved_format == expected_format
+
+
+def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
+                                                    mistral_tokenizer):
+    messages = [{
+        "role": "system",
+        "content": [{
+            "type": "text",
+            "text": "You are a helpful assistant."
+        }, {
+            "type": "thinking",
+            "closed": True,
+            "thinking": "Only return the answer when you are confident."
+        }]
+    }, {
+        "role": "user",
+        "content": "What is 2+2?"
+    }, {
+        "role": "assistant",
+        "content": [{
+            "type": "text",
+            "text": "Let me think about it."
+        }, {
+            "type": "thinking",
+            "closed": True,
+            "thinking": "2+2 = 4"
+        }, {
+            "type": "text",
+            "text": "The answer is 4.",
+        }],
+    }]
+
+    conversation_with_thinking, _ = parse_chat_messages(
+        messages,
+        mistral_model_config,
+        mistral_tokenizer,
+        content_format="openai",
+    )
+
+    expected_conversation = [{
+        "role": "system",
+        "content": [{
+            "type": "text",
+            "text": "You are a helpful assistant."
+        }, {
+            "type": "text",
+            "text": "Only return the answer when you are confident."
+        }],
+    }, {
+        "role": "user",
+        "content": [{
+            "type": "text",
+            "text": "What is 2+2?"
+        }],
+    }, {
+        "role": "assistant",
+        "content": [
+            {
+                "type": "text",
+                "text": "Let me think about it."
+            },
+            {
+                "type": "text",
+                "text": "2+2 = 4"
+            },
+            {
+                "type": "text",
+                "text": "The answer is 4."
+            },
+        ]
+    }]
+
+    assert conversation_with_thinking == expected_conversation
+
+
+def test_apply_mistral_chat_template_thinking_chunk():
+    # Moved import here to avoid yapf and isort conflicts
+    from vllm.entrypoints.chat_utils import apply_mistral_chat_template
+    messages = [{
+        "role": "system",
+        "content": [{
+            "type": "text",
+            "text": "You are a helpful assistant."
+        }, {
+            "type": "thinking",
+            "closed": True,
+            "thinking": "Only return the answer when you are confident."
+        }]
+    }, {
+        "role": "user",
+        "content": "What is 2+2?"
+    }, {
+        "role": "assistant",
+        "content": [{
+            "type": "text",
+            "text": "Let me think about it."
+        }, {
+            "type": "thinking",
+            "closed": True,
+            "thinking": "2+2 = 4"
+        }, {
+            "type": "text",
+            "text": "The answer is 4.",
+        }],
+    }, {
+        "role": "user",
+        "content": "Thanks, what is 3+3?"
+    }]
+
+    # TODO(Julien): upon model release change to a tokenizer already configured.
+    # =================================================================
+    mistral_tokenizer = MistralTokenizer.from_pretrained(
+        "mistralai/Devstral-Small-2507")
+    assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
+    # Add think special tokens to the tokenizer
+    mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
+        rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
+    mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
+        rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
+    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
+        k: v
+        for k, v in
+        mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
+        if v not in {35, 36}
+    }
+    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
+        SpecialTokens.begin_think.value] = 35
+    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
+        SpecialTokens.end_think.value] = 36
+    mistral_tokenizer.instruct.BEGIN_THINK = 35
+    mistral_tokenizer.instruct.END_THINK = 36
+    # =================================================================
+
+    tokens_ids = apply_mistral_chat_template(mistral_tokenizer,
+                                             messages,
+                                             chat_template=None,
+                                             tools=None)
+
+    string_tokens = mistral_tokenizer.mistral.decode(
+        tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP)
+
+    expected_tokens = (
+        r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
+        r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
+        r"[INST]What is 2+2?[/INST]"
+        r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
+        r"[INST]Thanks, what is 3+3?[/INST]")
+
+    assert string_tokens == expected_tokens
tests/reasoning/test_mistral_reasoning_parser.py (new file, 341 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
                                                     Tekkenizer)

from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

parser_name = "mistral"


@pytest.fixture(scope="module")
def mistral_tokenizer():
    # TODO(Julien): upon model release change to a tokenizer already configured.
    # =================================================================
    mistral_tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Devstral-Small-2507")
    assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
    # Add think special tokens to the tokenizer
    mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
        rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
    mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
        rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
        k: v
        for k, v in
        mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
        if v not in {35, 36}
    }
    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
        SpecialTokens.begin_think.value] = 35
    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
        SpecialTokens.end_think.value] = 36
    mistral_tokenizer.instruct.BEGIN_THINK = 35
    mistral_tokenizer.instruct.END_THINK = 36
    # =================================================================
    return mistral_tokenizer


SIMPLE_REASONING = {
    "output": "This is a reasoning section[/THINK]This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
COMPLETE_REASONING = {
    "output": "This is a reasoning section[/THINK]",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
NO_CONTENT = {
    "output": "This is content",
    "reasoning_content": "This is content",
    "content": None,
    "is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
    "output": "This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
MULTIPLE_LINES = {
    "output": "This\nThat[/THINK]This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
    "output": "[/THINK]This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING = {
    "output": "[/THINK]This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
REASONING_WITH_THINK = {
    "output": "[THINK]This is a reasoning section[/THINK]This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
    "output": "[THINK]This is a reasoning section[/THINK]",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
    "output": "[THINK]This\nThat[/THINK]This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "[/THINK]This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
    "output": "[/THINK]This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
THINK_NO_END = {
    "output": "[THINK]This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
EMPTY = {
    "output": "",
    "reasoning_content": "",
    "content": None,
    "is_reasoning_end": False,
}
EMPTY_STREAMING = {
    "output": "",
    "reasoning_content": None,
    "content": None,
    "is_reasoning_end": False,
}
NEW_LINE = {
    "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "\nThis is the rest",
    "is_reasoning_end": True,
}
# Streaming cannot handle new lines at the beginning of the output
# because we need to support [THINK]...[/THINK] and [/THINK]...
# We cannot know if the text before [THINK] is reasoning content
# or not.
NEW_LINE_STREAMING = {
    "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
    "reasoning_content": "\nThis is a reasoning section",
    "content": "\nThis is the rest",
    "is_reasoning_end": True,
}

TEST_CASES = [
    pytest.param(False, SIMPLE_REASONING, id="simple_reasoning"),
    pytest.param(True, SIMPLE_REASONING, id="simple_reasoning_streaming"),
    pytest.param(False, COMPLETE_REASONING, id="complete_reasoning"),
    pytest.param(True, COMPLETE_REASONING, id="complete_reasoning_streaming"),
    pytest.param(False, NO_CONTENT, id="no_content_token"),
    pytest.param(True, NO_REASONING_STREAMING,
                 id="no_reasoning_token_streaming"),
    pytest.param(False, MULTIPLE_LINES, id="multiple_lines"),
    pytest.param(True, MULTIPLE_LINES, id="multiple_lines_streaming"),
    pytest.param(True, SHORTEST_REASONING, id="shortest"),
    pytest.param(False, SHORTEST_REASONING_NO_STREAMING,
                 id="shortest_streaming"),
    pytest.param(False, REASONING_WITH_THINK, id="reasoning_with_think"),
    pytest.param(True, REASONING_WITH_THINK,
                 id="reasoning_with_think_streaming"),
    pytest.param(False, COMPLETE_REASONING_WITH_THINK,
                 id="complete_reasoning_with_think"),
    pytest.param(True, COMPLETE_REASONING_WITH_THINK,
                 id="complete_reasoning_with_think_streaming"),
    pytest.param(False, MULTIPLE_LINES_WITH_THINK,
                 id="multiple_lines_with_think"),
    pytest.param(True, MULTIPLE_LINES_WITH_THINK,
                 id="multiple_lines_with_think_streaming"),
    pytest.param(False, SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
                 id="shortest_with_think"),
    pytest.param(True, SHORTEST_REASONING_WITH_THINK,
                 id="shortest_with_think_streaming"),
    pytest.param(False, THINK_NO_END, id="think_no_end"),
    pytest.param(True, THINK_NO_END, id="think_no_end_streaming"),
    pytest.param(False, EMPTY, id="empty"),
    pytest.param(True, EMPTY_STREAMING, id="empty_streaming"),
    pytest.param(False, NEW_LINE, id="new_line"),
    pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_mistral_reasoning(
    streaming: bool,
    param_dict: dict,
    mistral_tokenizer: MistralTokenizer,
):
    output = param_dict["output"]

    index_think = output.find("[THINK]")
    len_think = len("[THINK]")
    index_end_think = output.find("[/THINK]")
    len_end_think = len("[/THINK]")

    # encode everything to tokens ids
    output_tokens = []
    if index_think != -1:
        output_before_think = output[:index_think]
        output_tokens += mistral_tokenizer.tokenizer.encode(
            output_before_think, False, False)
        output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK]

        if index_end_think != -1:
            output_middle = output[index_think + len_think:index_end_think]
            output_after_think = output[index_end_think + len_end_think:]
            output_tokens += mistral_tokenizer.tokenizer.encode(
                output_middle, False, False)
            output_tokens += [mistral_tokenizer.instruct.END_THINK]
            output_tokens += mistral_tokenizer.tokenizer.encode(
                output_after_think, False, False)
        else:
            output_middle = output[index_think + len_think:]
            output_tokens += mistral_tokenizer.tokenizer.encode(
                output_middle, False, False)
    elif index_end_think != -1:
        output_before_think = output[:index_end_think]
        output_after_think = output[index_end_think + len_end_think:]
        output_tokens += mistral_tokenizer.tokenizer.encode(
            output_before_think, False, False)
        output_tokens += [mistral_tokenizer.instruct.END_THINK]
        output_tokens += mistral_tokenizer.tokenizer.encode(
            output_after_think, False, False)
    else:
        output_tokens += mistral_tokenizer.tokenizer.encode(
            output, False, False)

    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(mistral_tokenizer)

    reasoning, content = run_reasoning_extraction_mistral(parser,
                                                          output_tokens,
                                                          streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]

    # Test is_reasoning_end
    is_reasoning_end = parser.is_reasoning_end(output_tokens)
    assert is_reasoning_end == param_dict["is_reasoning_end"]

    # Test extract_content
    if param_dict["content"] is not None:
        content = parser.extract_content_ids(output_tokens)
        assert content == mistral_tokenizer.tokenizer.encode(
            param_dict["content"], bos=False, eos=False)
    else:
        content = parser.extract_content_ids(output_tokens)
        assert content == []
tests/reasoning/utils.py
@@ -6,6 +6,7 @@ from typing import Optional, Union
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
 from vllm.reasoning import ReasoningParser
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer


 class StreamingReasoningReconstructor:
@@ -54,6 +55,32 @@ def run_reasoning_extraction(
     return reasoning, content


+def run_reasoning_extraction_mistral(
+    reasoning_parser: ReasoningParser,
+    model_output: list[int],
+    request: Union[ChatCompletionRequest, None] = None,
+    streaming: bool = False,
+) -> tuple[Optional[str], Optional[str]]:
+    assert isinstance(reasoning_parser.model_tokenizer,
+                      MistralTokenizer), type(reasoning_parser.model_tokenizer)
+    if streaming:
+        reconstructor = run_reasoning_extraction_streaming_mistral(
+            reasoning_parser,
+            model_output,
+            request,
+        )
+        return (
+            reconstructor.reasoning_content,
+            reconstructor.other_content or None,
+        )
+    else:
+        str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
+            model_output)
+        reasoning, content = run_reasoning_extraction_nonstreaming(
+            reasoning_parser, str_output, request)
+        return reasoning, content
+
+
 def run_reasoning_extraction_nonstreaming(
     reasoning_parser: ReasoningParser,
     model_output: list[str],
@@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming(
         previous_text = current_text
         previous_tokens = current_tokens
     return reconstructor
+
+
+def run_reasoning_extraction_streaming_mistral(
+    reasoning_parser: ReasoningParser,
+    model_deltas: list[int],
+    request: Union[ChatCompletionRequest, None] = None,
+) -> StreamingReasoningReconstructor:
+    assert isinstance(reasoning_parser.model_tokenizer,
+                      MistralTokenizer), type(reasoning_parser.model_tokenizer)
+    request = request or ChatCompletionRequest(messages=[], model="test-model")
+    reconstructor = StreamingReasoningReconstructor()
+    previous_text = ""
+    previous_tokens: list[int] = []
+    for model_delta in model_deltas:
+        token_delta = [model_delta]
+        delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
+            [model_delta])[0]
+        current_text = previous_text + delta
+        current_tokens = previous_tokens + token_delta
+        delta_message = reasoning_parser.extract_reasoning_content_streaming(
+            previous_text,
+            current_text,
+            delta,
+            previous_tokens,
+            current_tokens,
+            token_delta,
+        )
+        if delta_message is not None:
+            reconstructor.append_delta(delta_message)
+        previous_text = current_text
+        previous_tokens = current_tokens
+    return reconstructor
vllm/entrypoints/chat_utils.py
@@ -151,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
     video_url: Required[str]


+class CustomThinkCompletionContentParam(TypedDict, total=False):
+    """A Think Completion Content Param that accepts a plain text and a boolean.
+
+    Example:
+    {
+        "thinking": "I am thinking about the answer",
+        "closed": True,
+        "type": "thinking"
+    }
+    """
+
+    thinking: Required[str]
+    """The thinking content."""
+
+    closed: bool
+    """Whether the thinking is closed."""
+
+    type: Required[Literal["thinking"]]
+    """The thinking type."""
+
+
 ChatCompletionContentPartParam: TypeAlias = Union[
     OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
     ChatCompletionContentPartInputAudioParam,
@@ -159,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
     CustomChatCompletionContentSimpleImageParam,
     ChatCompletionContentPartImageEmbedsParam,
     CustomChatCompletionContentSimpleAudioParam,
-    CustomChatCompletionContentSimpleVideoParam, str]
+    CustomChatCompletionContentSimpleVideoParam, str,
+    CustomThinkCompletionContentParam]


 class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -938,6 +960,7 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 _PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
+_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
 # Need to validate url objects
 _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
 _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
@@ -954,6 +977,8 @@ MM_PARSER_MAP: dict[
 ] = {
     "text":
     lambda part: _TextParser(part).get("text", None),
+    "thinking":
+    lambda part: _ThinkParser(part).get("thinking", None),
     "input_text":
     lambda part: _TextParser(part).get("text", None),
     "input_image":
@@ -1100,7 +1125,7 @@ def _parse_chat_message_content_part(
                        "with empty / unparsable content.", part, part_type)
         return None

-    if part_type in ("text", "input_text", "refusal"):
+    if part_type in ("text", "input_text", "refusal", "thinking"):
         str_content = cast(str, content)
         if wrap_dicts:
             return {'type': 'text', 'text': str_content}
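
As a rough sketch of what the new "thinking" branch amounts to (based on the MM_PARSER_MAP entry and the part_type check above, not an exact trace of the full parsing path): the part's "thinking" field is pulled out like ordinary text, so with dict wrapping a thinking part comes back as a plain text part.

# Sketch: a "thinking" part is read via _ThinkParser and then treated like text,
# so with wrap_dicts=True it resurfaces as {"type": "text", "text": ...}.
part = {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}
text = part.get("thinking")  # what the MM_PARSER_MAP lambda extracts
wrapped = {"type": "text", "text": text}
assert wrapped == {"type": "text", "text": "2+2 = 4"}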
vllm/reasoning/__init__.py
@@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
+from .mistral_reasoning_parser import MistralReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser

 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "HunyuanA13BReasoningParser",
     "Qwen3ReasoningParser",
     "Glm4MoeModelReasoningParser",
+    "MistralReasoningParser",
 ]
vllm/reasoning/mistral_reasoning_parser.py (new file, 47 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import (
    DeepSeekR1ReasoningParser)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

logger = init_logger(__name__)


@ReasoningParserManager.register_module("mistral")
class MistralReasoningParser(DeepSeekR1ReasoningParser):
    """
    Reasoning parser for Mistral models.

    Mistral models use [THINK]...[/THINK] tokens to denote reasoning text.
    This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: MistralTokenizer):
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "The tokenizer must be an instance of MistralTokenizer.")

        ReasoningParser.__init__(self, tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        from mistral_common.tokens.tokenizers.base import SpecialTokens

        self.start_token = SpecialTokens.begin_think
        self.end_token = SpecialTokens.end_think

        self.start_token_id = tokenizer.tokenizer.get_control_token(
            self.start_token)
        self.end_token_id = tokenizer.tokenizer.get_control_token(
            self.end_token)

        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "Mistral reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")
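
A rough usage sketch for the new parser, with two caveats: it assumes a Mistral tokenizer whose [THINK]/[/THINK] control tokens are already configured (the tests above have to patch them in by hand), and it leans on the extract_reasoning_content interface inherited from DeepSeekR1ReasoningParser. The model name is a placeholder.

# Sketch only; "mistralai/<model-with-think-tokens>" is a placeholder and the
# tokenizer must already expose begin_think/end_think control tokens,
# otherwise the parser constructor raises.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.from_pretrained("mistralai/<model-with-think-tokens>")
parser = ReasoningParserManager.get_reasoning_parser("mistral")(tokenizer)

# Non-streaming extraction runs on decoded text that still contains the
# [THINK]...[/THINK] markers (kept by the decode change further down).
request = ChatCompletionRequest(messages=[], model="test-model")
reasoning, content = parser.extract_reasoning_content(
    "[THINK]2+2 = 4[/THINK]The answer is 4.", request)
# Expected roughly: reasoning == "2+2 = 4", content == "The answer is 4."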
vllm/transformers_utils/tokenizers/mistral.py
@@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]):
     return matched_files[0]


+def _aggregate_content(content: list) -> list[dict[str, Any]]:
+    aggregated_content: list[dict[str, Any]] = []
+    for chunk in content:
+        if (chunk.get("type") == "text" and aggregated_content
+                and aggregated_content[-1].get("type") == "text"):
+            aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
+        else:
+            aggregated_content.append(chunk)
+    if len(aggregated_content) == 1 and aggregated_content[0].get(
+            "type") == "text":
+        content = aggregated_content[0]["text"]
+    return content
+
+
 def make_mistral_chat_completion_request(
     messages: list["ChatCompletionMessageParam"],
     tools: Optional[list[dict[str,
@@ -162,10 +177,10 @@ def make_mistral_chat_completion_request(

         # Convert list text content to string
         if message.get("role") in ("assistant", "tool"):
-            content = message.get("content")
+            content: Any = message.get("content")
             if isinstance(content, list):
-                content = "\n".join(chunk.get("text") for chunk in content)
+                content = _aggregate_content(content)
             message["content"] = content

     # The Mistral client, in comparison to the OpenAI client, requires the
     # "parameters" dict to be present, even if it's empty.
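
To make the effect of _aggregate_content concrete, here is a small sketch of the behaviour implied by the code above: consecutive text chunks are merged with a blank line, non-text chunks such as thinking chunks stay as separate entries, and a lone remaining text chunk collapses to a plain string (the chunk values are illustrative).

# Sketch of _aggregate_content behaviour on an assistant message's chunk list.
chunks = [
    {"type": "text", "text": "Let me think about it."},
    {"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
    {"type": "text", "text": "The answer is 4."},
    {"type": "text", "text": "Ask me anything else."},
]
# The two trailing text chunks merge with "\n\n"; the thinking chunk stays put:
# [
#     {"type": "text", "text": "Let me think about it."},
#     {"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
#     {"type": "text", "text": "The answer is 4.\n\nAsk me anything else."},
# ]
# With only text chunks, e.g. [{"type": "text", "text": "Hi"}], the whole list
# collapses to the plain string "Hi".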
@@ -465,6 +480,8 @@ class MistralTokenizer(TokenizerBase):
         skip_special_tokens: bool = True,
     ) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
+        from mistral_common.tokens.tokenizers.instruct import (
+            InstructTokenizerV13)

         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
@@ -474,10 +491,18 @@ class MistralTokenizer(TokenizerBase):
         assert self.is_tekken or self.is_spm, type(self.tokenizer)

         if self.is_tekken:
-            # skip special tokens except tool call
-            ids = [
-                i for i in ids if i > self.tokenizer.num_special_tokens or i ==
+            # skip special tokens except tool call and think tokens
+            non_skip_special_tokens = {
                 self.tokenizer.get_control_token(SpecialTokens.tool_calls)
+            }
+            if isinstance(self.instruct, InstructTokenizerV13):
+                if self.instruct.BEGIN_THINK:
+                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
+                if self.instruct.END_THINK:
+                    non_skip_special_tokens.add(self.instruct.END_THINK)
+            ids = [
+                i for i in ids if i > self.tokenizer.num_special_tokens
+                or i in non_skip_special_tokens
             ]

             tokens = [self.tokenizer.id_to_piece(id) for id in ids]