[Feature] Support MiniMax-M1 function calls features (#20297)
Signed-off-by: QscQ <qscqesze@gmail.com>
Signed-off-by: qingjun <qingjun@minimaxi.com>
parent 4ff61ababa
commit 363528de27
@@ -264,6 +264,15 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup
 
 Flags: `--tool-call-parser hermes`
 
+### MiniMax Models (`minimax_m1`)
+
+Supported models:
+
+* `MiniMaxAi/MiniMax-M1-40k` (use with <gh-file:examples/tool_chat_template_minimax_m1.jinja>)
+* `MiniMaxAi/MiniMax-M1-80k` (use with <gh-file:examples/tool_chat_template_minimax_m1.jinja>)
+
+Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja`
+
 ### DeepSeek-V3 Models (`deepseek_v3`)
 
 Supported models:
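For context, these flags are passed when launching the server. An illustrative invocation (the model path is an example, and `--enable-auto-tool-choice` is the usual prerequisite for any tool parser, described earlier in the tool-calling documentation) would be `vllm serve MiniMaxAi/MiniMax-M1-40k --enable-auto-tool-choice --tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja`.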
examples/tool_chat_template_minimax_m1.jinja (new file, 91 lines)
@@ -0,0 +1,91 @@
{{ '<begin_of_document>' -}}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}

{#- Extract system message #}
{% set ns = namespace(system_prompt='') -%}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set ns.system_prompt = messages[0]['content']|trim %}
{%- else %}
{%- set ns.system_prompt = messages[0]['content'][0]['text']|trim %}
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{%- if tools is not none %}
{%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
{%- else %}
{%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
{%- endif %}
{%- endif %}

{#- System message #}
{%- if ns.system_prompt != '' %}
{{ '<beginning_of_sentence>system ai_setting=assistant\n' + ns.system_prompt + '<end_of_sentence>\n' -}}
{%- endif %}

{#- Tools configuration #}
{%- if tools is not none %}
{{ '<beginning_of_sentence>system tool_setting=tools\nYou are provided with these tools:\n<tools>\n' -}}
{%- for tool in tools %}
{{ tool | tojson ~ '\n' -}}
{%- endfor %}
{{ '</tools>\n\nIf you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:\n<tool_calls>\n{"name": <tool-name>, "arguments": <args-json-object>}\n...\n</tool_calls><end_of_sentence>\n' -}}
{%- endif %}

{#- Process messages #}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{%- if message['role'] == 'user' %}
{{ '<beginning_of_sentence>user name=user\n' -}}
{%- if message['content'] is string %}
{{ message['content']|trim -}}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{ content['text']|trim -}}
{%- endif %}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- elif message['role'] == 'assistant' %}
{{ '<beginning_of_sentence>ai name=assistant\n' -}}
{%- if message['content'] is string %}
{{ message['content']|trim -}}
{%- else %}
{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}
{{ content['text']|trim -}}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- endif %}
{%- elif 'tool_calls' in message %}
{{ '<beginning_of_sentence>ai name=assistant\n<tool_calls>\n' -}}
{%- for tool_call in message.tool_calls %}
{{ '{"name": "' + tool_call.function.name + '", "arguments": ' + tool_call.function.arguments | tojson + '}\n' -}}
{%- endfor %}
{{ '</tool_calls><end_of_sentence>\n' -}}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{ '<beginning_of_sentence>tool name=tools\n' -}}
{%- if message.content is string %}
{{ 'tool result: ' + message.content + '\n\n' -}}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{ 'tool result: ' + content['text'] + '\n\n' -}}
{%- elif content.get('name') %}
{{ 'tool name: ' + content['name'] + '\ntool result: ' + content['text'] + '\n\n' -}}
{%- endif %}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{ '<beginning_of_sentence>ai name=assistant\n' -}}
{%- endif %}
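A quick way to see the prompt format this template produces (and that the new parser expects back from the model) is to render it offline. The sketch below is illustrative and not part of the diff: it assumes the MiniMax-M1 tokenizer can be fetched from the Hub and a recent transformers release whose apply_chat_template accepts a tools list and an explicit chat_template string.

# Illustrative sketch, not part of the diff.
from pathlib import Path

from transformers import AutoTokenizer

template = Path("examples/tool_chat_template_minimax_m1.jinja").read_text()
tokenizer = AutoTokenizer.from_pretrained("MiniMaxAi/MiniMax-M1-40k",
                                          trust_remote_code=True)

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}]

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What's the weather in Dallas?"}],
    chat_template=template,
    tools=tools,
    add_generation_prompt=True,
    tokenize=False,
)
# The rendered prompt starts with '<begin_of_document>', lists the tools inside
# a <tools> block, and ends with the '<beginning_of_sentence>ai name=assistant'
# generation prompt.
print(prompt)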
tests/tool_use/test_minimax_tool_parser.py (new file, 371 lines)
@@ -0,0 +1,371 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501

import json

import pytest

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

# Use a common model that is likely to be available
MODEL = "MiniMaxAi/MiniMax-M1-40k"


@pytest.fixture(scope="module")
def minimax_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def minimax_tool_parser(minimax_tokenizer):
    return MinimaxToolParser(minimax_tokenizer)


def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
                                                    expected_tool_calls):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def test_extract_tool_calls_no_tools(minimax_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
        "tool_call_with_single_line_json",
        "tool_call_incomplete_tag",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_calls>""",
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            None,
        ),
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>""",
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            None,
        ),
        (
            """I'll help you check the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
</tool_calls>""",
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Seattle",
                        "state": "WA",
                        "unit": "celsius",
                    }),
                ))
            ],
            "I'll help you check the weather.",
        ),
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}}
</tool_calls>""",
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "New York",
                        "state": "NY",
                        "unit": "celsius",
                    }),
                ))
            ],
            None,
        ),
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""",
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Boston",
                        "state": "MA",
                    }),
                ))
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(minimax_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser):
    """Test that tool calls within thinking tags are removed during preprocessing."""
    model_output = """<think>Let me think about this. <tool_calls>
{"name": "fake_tool", "arguments": {"param": "value"}}
</tool_calls> This should be removed.</think>

I'll help you with that. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}}
</tool_calls>"""

    processed_output = minimax_tool_parser.preprocess_model_output(
        model_output)

    # The tool call within thinking tags should be removed
    assert "fake_tool" not in processed_output
    # But the thinking tag itself should remain
    assert "<think>" in processed_output
    assert "</think>" in processed_output
    # The actual tool call outside thinking tags should remain
    assert "get_current_weather" in processed_output


def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser):
    """Test tool extraction when thinking tags contain tool calls that should be ignored."""
    model_output = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>

Let me help you with the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>"""

    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[
        0].function.name == "get_current_weather"

    # Content extraction is based on the position of the first <tool_calls> in the original model_output
    # Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one
    expected_content = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>

Let me help you with the weather."""
    assert extracted_tool_calls.content == expected_content


def test_extract_tool_calls_invalid_json(minimax_tool_parser):
    """Test that invalid JSON in tool calls is handled gracefully."""
    model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{invalid json here}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""

    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Should extract only the valid JSON tool calls
    assert len(extracted_tool_calls.tool_calls) == 2
    assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
    assert extracted_tool_calls.tool_calls[
        1].function.name == "another_valid_tool"


def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser):
    """Test that tool calls missing name or arguments are filtered out."""
    model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{"name": "missing_args"}
{"arguments": {"city": "Portland"}}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""

    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    # Should extract only the valid tool calls with both name and arguments
    assert len(extracted_tool_calls.tool_calls) == 2
    assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
    assert extracted_tool_calls.tool_calls[
        1].function.name == "another_valid_tool"


def test_streaming_basic_functionality(minimax_tool_parser):
    """Test basic streaming functionality."""
    # Reset streaming state
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []

    # Test with a simple tool call
    current_text = """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle"}}
</tool_calls>"""

    # First call should handle the initial setup
    result = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="</tool_calls>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The result might be None or contain tool call information
    # This depends on the internal state management
    if result is not None and hasattr(result,
                                      'tool_calls') and result.tool_calls:
        assert len(result.tool_calls) >= 0


def test_streaming_with_content_before_tool_calls(minimax_tool_parser):
    """Test streaming when there's content before tool calls."""
    # Reset streaming state
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []

    current_text = "I'll help you with that. <tool_calls>"

    # When there's content before tool calls, it should be returned as content
    result = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text=" with that. <tool_calls>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    if result is not None and hasattr(result, 'content'):
        # Should contain some content
        assert result.content is not None


def test_streaming_no_tool_calls(minimax_tool_parser):
    """Test streaming when there are no tool calls."""
    current_text = "This is just regular text without any tool calls."

    result = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text=current_text,
        delta_text=" without any tool calls.",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Should return the delta text as content
    assert result is not None
    assert hasattr(result, 'content')
    assert result.content == " without any tool calls."


def test_streaming_with_thinking_tags(minimax_tool_parser):
    """Test streaming with thinking tags that contain tool calls."""
    # Reset streaming state
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []

    current_text = """<think><tool_calls>{"name": "ignored", "arguments": {}}</tool_calls></think><tool_calls>{"name": "real_tool", "arguments": {"param": "value"}}</tool_calls>"""

    result = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text=current_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The preprocessing should remove tool calls from thinking tags
    # and only process the real tool call
    if result is not None and hasattr(result,
                                      'tool_calls') and result.tool_calls:
        for tool_call in result.tool_calls:
            assert tool_call.function.name != "ignored"


def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
    """Test that multiline JSON in tool calls is not currently supported."""
    model_output = """<tool_calls>
{
"name": "get_current_weather",
"arguments": {
"city": "New York",
"state": "NY",
"unit": "celsius"
}
}
</tool_calls>"""

    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    # Multiline JSON is currently not supported, should return no tools called
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content is None
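These tests drive the parser directly and do not need a GPU; assuming the MiniMax-M1 tokenizer can be fetched, they can be run with, for example, `pytest tests/tool_use/test_minimax_tool_parser.py -q`.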
vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -10,6 +10,7 @@ from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
 from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
 from .llama_tool_parser import Llama3JsonToolParser
+from .minimax_tool_parser import MinimaxToolParser
 from .mistral_tool_parser import MistralToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
@@ -20,5 +21,5 @@ __all__ = [
     "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
     "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
-    "DeepSeekV3ToolParser", "xLAMToolParser"
+    "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
 ]
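With the import and `__all__` entry in place, the parser is importable from the package, and the `@ToolParserManager.register_module("minimax")` decorator in the new module below is what makes it resolvable by the `--tool-call-parser minimax` name. A small offline sketch (illustrative, not part of the diff), mirroring the tests and assuming the MiniMax-M1 tokenizer is available:

# Illustrative sketch, not part of the diff.
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="MiniMaxAi/MiniMax-M1-40k")
parser = MinimaxToolParser(tokenizer)

model_output = ('<tool_calls>\n'
                '{"name": "get_current_weather", "arguments": {"city": "Dallas"}}\n'
                '</tool_calls>')
result = parser.extract_tool_calls(model_output,
                                   request=None)  # type: ignore[arg-type]

print(result.tools_called)                      # True
print(result.tool_calls[0].function.name)       # get_current_weather
print(result.tool_calls[0].function.arguments)  # {"city": "Dallas"}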
vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py (new file, 369 lines)
@@ -0,0 +1,369 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence
from typing import Union

import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)


@ToolParserManager.register_module("minimax")
class MinimaxToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<tool_calls>"
        self.tool_call_end_token: str = "</tool_calls>"

        self.tool_call_regex = re.compile(
            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)

        # Add regex pattern for thinking tag
        self.thinking_tag_pattern = r"<think>(.*?)</think>"

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            logger.warning(
                "Minimax Tool parser could not locate tool call start/end "
                "tokens in the tokenizer. Falling back to string matching.")

    def preprocess_model_output(self, model_output: str) -> str:
        """
        Remove tool calls from within thinking tags to avoid processing them.
        """

        def remove_tool_calls_from_think(match):
            think_content = match.group(1)
            # Remove tool_calls from within the think tag
            cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
                                     "",
                                     think_content,
                                     flags=re.DOTALL)
            return f"<think>{cleaned_content}</think>"

        # Process thinking tags and remove tool_calls from within them
        processed_output = re.sub(self.thinking_tag_pattern,
                                  remove_tool_calls_from_think,
                                  model_output,
                                  flags=re.DOTALL)

        return processed_output

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:

        # Preprocess to remove tool calls from thinking tags
        processed_output = self.preprocess_model_output(model_output)

        if self.tool_call_start_token not in processed_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            function_call_tuples = (
                self.tool_call_regex.findall(processed_output))

            raw_function_calls = []
            for match in function_call_tuples:
                tool_call_content = match[0] if match[0] else match[1]
                if tool_call_content.strip():
                    lines = tool_call_content.strip().split('\n')
                    for line in lines:
                        line = line.strip()
                        if line and line.startswith('{') and line.endswith(
                                '}'):
                            try:
                                parsed_call = json.loads(line)
                                raw_function_calls.append(parsed_call)
                            except json.JSONDecodeError:
                                continue

            tool_calls = []
            for function_call in raw_function_calls:
                if "name" in function_call and "arguments" in function_call:
                    tool_calls.append(
                        ToolCall(type="function",
                                 function=FunctionCall(
                                     name=function_call["name"],
                                     arguments=json.dumps(
                                         function_call["arguments"],
                                         ensure_ascii=False))))

            # Extract content before the first valid tool call
            # Find the position in processed output, then map back to original
            processed_pos = processed_output.find(self.tool_call_start_token)
            if processed_pos != -1:
                # Get the content before tool calls in processed output
                processed_content = processed_output[:processed_pos].strip()

                if processed_content:
                    # Find the end of this content in the original output
                    # Look for the last non-empty line of processed content
                    lines = processed_content.split('\n')
                    for line in reversed(lines):
                        line = line.strip()
                        if line:
                            # Find this line in original output
                            pos = model_output.find(line)
                            if pos != -1:
                                content = model_output[:pos + len(line)]
                                break
                    else:
                        content = ""
                else:
                    content = ""
            else:
                content = model_output

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content.strip() if content.strip() else None)

        except Exception:
            logger.exception(
                "An unexpected error occurred during tool call extraction.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # Preprocess to remove tool calls from thinking tags
        processed_current_text = self.preprocess_model_output(current_text)

        if self.tool_call_start_token not in processed_current_text:
            return DeltaMessage(content=delta_text)

        if (self.tool_call_start_token_id is not None
                and self.tool_call_start_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            return None

        original_tool_call_start_pos = current_text.find(
            self.tool_call_start_token)
        if original_tool_call_start_pos > 0:
            delta_start_pos = len(current_text) - len(delta_text)
            if delta_start_pos < original_tool_call_start_pos:
                content_part = delta_text
                if delta_start_pos + len(
                        delta_text) > original_tool_call_start_pos:
                    content_part = delta_text[:original_tool_call_start_pos -
                                              delta_start_pos]
                if content_part:
                    return DeltaMessage(content=content_part)

        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_content = processed_current_text.split(
                self.tool_call_start_token)[-1].split(
                    self.tool_call_end_token)[0]

            tool_call_arr = []
            if parsable_content.strip():
                lines = parsable_content.strip().split('\n')
                for line in lines:
                    line = line.strip()
                    if line and (line.startswith('{') or '"name"' in line):
                        try:
                            if line.endswith('}'):
                                parsed_call = json.loads(line)
                                tool_call_arr.append(parsed_call)
                            else:
                                parsed_call = partial_json_parser.loads(
                                    line, flags)
                                if parsed_call and isinstance(
                                        parsed_call, dict):
                                    tool_call_arr.append(parsed_call)
                        except (json.JSONDecodeError, partial_json_parser.core.
                                exceptions.MalformedJSON):
                            continue

            current_tool_call: dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > self.current_tool_id >= 0 else {}

            if len(tool_call_arr) == 0:
                return None

            # Starting a new tool in the array
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # Handle any missed arguments from previous tool
                if self.current_tool_id >= 0 and self.current_tool_id < len(
                        self.prev_tool_call_arr):
                    prev_tool_call = self.prev_tool_call_arr[
                        self.current_tool_id]
                    diff_arguments = prev_tool_call.get("arguments")

                    if diff_arguments:
                        diff_arguments_json = json.dumps(diff_arguments,
                                                         ensure_ascii=False)
                        already_streamed = self.streamed_args_for_tool[
                            self.
                            current_tool_id] if self.current_tool_id < len(
                                self.streamed_args_for_tool) else ""

                        if diff_arguments_json != already_streamed:
                            diff = diff_arguments_json[len(already_streamed):]
                            delta = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=self.current_tool_id,
                                              function=DeltaFunctionCall(
                                                  arguments=diff).model_dump(
                                                      exclude_none=True))
                            ])
                            if self.current_tool_id < len(
                                    self.streamed_args_for_tool):
                                self.streamed_args_for_tool[
                                    self.current_tool_id] = diff_arguments_json
                        else:
                            delta = None
                    else:
                        delta = None
                else:
                    delta = None

                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # Send tool name if not sent yet
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=random_tool_call_id(),
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # Stream arguments
            else:
                prev_arguments = None
                if (self.current_tool_id < len(self.prev_tool_call_arr)
                        and self.prev_tool_call_arr[self.current_tool_id]):
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                cur_arguments = current_tool_call.get("arguments")

                if not cur_arguments and not prev_arguments:
                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "Arguments reset mid-call, skipping streaming")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)
                    logger.debug("First tokens in arguments received: %s",
                                 cur_arguments_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=cur_arguments_json).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] = cur_arguments_json

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments,
                                               ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments,
                                                ensure_ascii=False)

                    logger.debug("Searching for diff between \n%s\n%s",
                                 cur_args_json, prev_args_json)

                    already_streamed = self.streamed_args_for_tool[
                        self.current_tool_id] if self.current_tool_id < len(
                            self.streamed_args_for_tool) else ""

                    if cur_args_json.startswith(already_streamed):
                        argument_diff = cur_args_json[len(already_streamed):]
                    elif cur_args_json != already_streamed:
                        argument_diff = cur_args_json
                        self.streamed_args_for_tool[self.current_tool_id] = ""
                    else:
                        argument_diff = ""

                    if argument_diff:
                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("An unexpected error occurred "
                             "during streaming tool call handling.")
            return None
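End to end, the parser is exercised through the OpenAI-compatible chat API. The sketch below is illustrative and not part of the diff: it assumes a server is already running with the MiniMax flags shown in the documentation hunk above, and it uses the standard `openai` client; the port, API key, and example tool schema are placeholders.

# Illustrative sketch, not part of the diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state"],
        },
    },
}]

response = client.chat.completions.create(
    model="MiniMaxAi/MiniMax-M1-40k",
    messages=[{"role": "user", "content": "What's the weather in Dallas, TX?"}],
    tools=tools,
    tool_choice="auto",
)

# When the model emits a <tool_calls> block, the parser above converts it into
# standard OpenAI tool_calls entries on the response message.
message = response.choices[0].message
if message.tool_calls:
    for call in message.tool_calls:
        print(call.function.name, call.function.arguments)
else:
    print(message.content)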