mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 15:35:41 +08:00
[IMPROVEMENT] Change MistralReasoningParser behavior (#30391)
Signed-off-by: juliendenize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
305b168a9f
commit
aa3c250c48
@ -18,47 +18,53 @@ def mistral_tokenizer():
|
||||
return mistral_tokenizer
|
||||
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
INVALID_SIMPLE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning sectionThis is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
INVALID_COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "This is content",
|
||||
"reasoning": "This is content",
|
||||
"output": "[THINK]This is reasoning",
|
||||
"reasoning": "This is reasoning",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is content",
|
||||
"reasoning": None,
|
||||
"content": "This is content",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING_STREAMING = {
|
||||
"output": "This is a reasoning section",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
INVALID_MULTIPLE_LINES = {
|
||||
"output": "This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This\nThatThis is the rest\nThat",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING = {
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
INVALID_SHORTEST_REASONING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
|
||||
@ -78,17 +84,17 @@ MULTIPLE_LINES_WITH_THINK = {
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_WITH_THINK = {
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
INVALID_SHORTEST_REASONING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
THINK_NO_END = {
|
||||
"output": "[THINK]This is a reasoning section",
|
||||
@ -98,8 +104,8 @@ THINK_NO_END = {
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": "",
|
||||
"content": None,
|
||||
"reasoning": None,
|
||||
"content": "",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY_STREAMING = {
|
||||
@ -109,47 +115,48 @@ EMPTY_STREAMING = {
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
# Streaming cannot handle new lines at the beginning of the output
|
||||
# because we need to support [THINK]...[/THINK] and [/THINK]...
|
||||
# We cannot know if the text before [THINK] is reasoning content
|
||||
# or not.
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "\nThis is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
INVALID_SIMPLE_REASONING,
|
||||
id="invalid_simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
INVALID_SIMPLE_REASONING,
|
||||
id="invalid_simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
INVALID_COMPLETE_REASONING,
|
||||
id="invalid_complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_streaming",
|
||||
INVALID_COMPLETE_REASONING,
|
||||
id="invalid_complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_CONTENT,
|
||||
id="no_content_token",
|
||||
id="no_content",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_REASONING,
|
||||
id="no_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
@ -158,23 +165,23 @@ TEST_CASES = [
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines",
|
||||
INVALID_MULTIPLE_LINES,
|
||||
id="invalid_multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
INVALID_MULTIPLE_LINES,
|
||||
id="invalid_multiple_lines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING,
|
||||
id="shortest",
|
||||
INVALID_SHORTEST_REASONING,
|
||||
id="invalid_shortest",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING,
|
||||
id="shortest_streaming",
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING,
|
||||
id="invalid_shortest_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
@ -208,13 +215,13 @@ TEST_CASES = [
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="shortest_with_think",
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="invalid_shortest_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING_WITH_THINK,
|
||||
id="shortest_with_think_streaming",
|
||||
INVALID_SHORTEST_REASONING_WITH_THINK,
|
||||
id="invalid_shortest_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
@ -316,10 +323,26 @@ def test_mistral_reasoning(
|
||||
|
||||
# Test extract_content
|
||||
if param_dict["content"] is not None:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == mistral_tokenizer.tokenizer.encode(
|
||||
param_dict["content"], bos=False, eos=False
|
||||
# Handle the case where there are tokens outputted before Thinking.
|
||||
# This should not occur if the model is well trained and prompted.
|
||||
if "[THINK]" in param_dict["output"] and not param_dict["output"].startswith(
|
||||
"[THINK]"
|
||||
):
|
||||
before_content = param_dict["output"].split("[THINK]")[0]
|
||||
before_token_ids = mistral_tokenizer.tokenizer.encode(
|
||||
before_content, bos=False, eos=False
|
||||
)
|
||||
left_to_encode = param_dict["content"][len(before_content) :]
|
||||
# Normal situation.
|
||||
else:
|
||||
before_token_ids = []
|
||||
left_to_encode = param_dict["content"]
|
||||
|
||||
content_tokens = parser.extract_content_ids(output_tokens)
|
||||
expected_token_ids = before_token_ids + mistral_tokenizer.tokenizer.encode(
|
||||
left_to_encode, bos=False, eos=False
|
||||
)
|
||||
assert content_tokens == expected_token_ids
|
||||
else:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == []
|
||||
|
||||
@ -3,20 +3,29 @@
|
||||
|
||||
from functools import cached_property
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MistralReasoningParser(DeepSeekR1ReasoningParser):
|
||||
class MistralReasoningParser(BaseThinkingReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Mistral models.
|
||||
|
||||
The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning
|
||||
The Mistral models uses `[THINK]`...`[/THINK]` tokens to denote reasoning
|
||||
text. This parser extracts the reasoning content from the model output.
|
||||
|
||||
A valid reasoning trace should always start with a `[THINK]` token and end with
|
||||
a `[/THINK]` token.
|
||||
|
||||
If `[THINK]` token is not generated, then this parser only returns content.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
|
||||
@ -53,3 +62,93 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
|
||||
return SpecialTokens.end_think
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
has_eot_token = False
|
||||
|
||||
for id in input_ids[::-1]:
|
||||
if id == self.start_token_id:
|
||||
# Reasoning ends only if a BOT token is found before a EOT token.
|
||||
return has_eot_token
|
||||
elif id == self.end_token_id:
|
||||
has_eot_token = True
|
||||
return False
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Extract the content
|
||||
"""
|
||||
has_bot_token = False
|
||||
has_eot_token = False
|
||||
bot_token_index = -1
|
||||
eot_token_index = -1
|
||||
# One for loop instead of multiple lookups
|
||||
for i, token_id in enumerate(input_ids):
|
||||
# We filter that we have multiple BOT tokens which should not
|
||||
# happen for a well prompted trained model
|
||||
if token_id == self.start_token_id and not has_bot_token:
|
||||
has_bot_token = True
|
||||
bot_token_index = i
|
||||
elif token_id == self.end_token_id:
|
||||
has_eot_token = True
|
||||
eot_token_index = i
|
||||
break
|
||||
|
||||
# 1. Only BOT has been outputted
|
||||
if has_bot_token and not has_eot_token:
|
||||
# Should be = [] if model is well prompted and trained.
|
||||
return input_ids[:bot_token_index]
|
||||
# 2. Neither BOT or EOT have been outputted
|
||||
elif not has_bot_token and not has_eot_token:
|
||||
return input_ids
|
||||
# 3. Both BOT and EOT have been outputted.
|
||||
elif has_bot_token and has_eot_token:
|
||||
return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :]
|
||||
# 4. Only EOT has been outputted => this should not have occured for a model
|
||||
# well prompted and trained.
|
||||
else:
|
||||
return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]
|
||||
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
"""
|
||||
if not model_output:
|
||||
return (None, "")
|
||||
|
||||
# Check if the start token is present in the model output, remove it
|
||||
# if it is present.
|
||||
prev_bot_token, bot_token, post_bot_token = model_output.partition(
|
||||
self.start_token
|
||||
)
|
||||
|
||||
has_bot_token = bool(bot_token)
|
||||
# Valid EOT tokens should follow BOT token
|
||||
has_valid_eot_token = has_bot_token and self.end_token in post_bot_token
|
||||
|
||||
# 1. If there is BOT token followed by EOT token
|
||||
if has_bot_token and has_valid_eot_token:
|
||||
prev_eot_token, _, post_eot_token = post_bot_token.partition(self.end_token)
|
||||
# If model is well prompted and trained prev_bot_token should be ""
|
||||
content = prev_bot_token + post_eot_token
|
||||
return prev_eot_token, content if content else None
|
||||
# 2. Only BOT token
|
||||
elif has_bot_token:
|
||||
# If model is well prompted and trained prev_bot_token should be ""
|
||||
return post_bot_token, prev_bot_token if prev_bot_token else None
|
||||
# 3. EOT token has been outputted without BOT or neither has been outputted
|
||||
else:
|
||||
has_non_valid_eot_token = self.end_token in prev_bot_token
|
||||
# 3.a EOT token has been outputted without BOT
|
||||
# If model is well prompted and trained `has_non_valid_eot_token` should
|
||||
# be `False` and the parser outputs all tokens as 'content'
|
||||
if has_non_valid_eot_token:
|
||||
prev_eot_token, _, post_eot_token = prev_bot_token.partition(
|
||||
self.end_token
|
||||
)
|
||||
return None, prev_eot_token + post_eot_token
|
||||
# 3.b neither BOT or EOT have been outputted
|
||||
else:
|
||||
return None, prev_bot_token
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user