mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 19:55:45 +08:00
[Bugfix]: Fix the streaming output for function calls in the minimax (#22015)
Signed-off-by: QscQ <qscqesze@gmail.com> Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in:
parent
a00d8b236f
commit
5e9455ae8f
@ -3,10 +3,12 @@
|
|||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||||
|
FunctionCall, ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
|
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
@ -24,6 +26,57 @@ def minimax_tool_parser(minimax_tokenizer):
|
|||||||
return MinimaxToolParser(minimax_tokenizer)
|
return MinimaxToolParser(minimax_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tools():
|
||||||
|
return [
|
||||||
|
ChatCompletionToolsParam(type="function",
|
||||||
|
function={
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name"
|
||||||
|
},
|
||||||
|
"state": {
|
||||||
|
"type": "string",
|
||||||
|
"description":
|
||||||
|
"The state code"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum":
|
||||||
|
["fahrenheit", "celsius"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city", "state"]
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
ChatCompletionToolsParam(type="function",
|
||||||
|
function={
|
||||||
|
"name": "calculate_area",
|
||||||
|
"description":
|
||||||
|
"Calculate area of a shape",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"shape": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"dimensions": {
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
"precision": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||||
expected_tool_calls: list[ToolCall]):
|
expected_tool_calls: list[ToolCall]):
|
||||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||||
@ -370,3 +423,794 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
|
|||||||
assert not extracted_tool_calls.tools_called
|
assert not extracted_tool_calls.tools_called
|
||||||
assert extracted_tool_calls.tool_calls == []
|
assert extracted_tool_calls.tool_calls == []
|
||||||
assert extracted_tool_calls.content is None
|
assert extracted_tool_calls.content is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_arguments_incremental_output(minimax_tool_parser):
|
||||||
|
"""Test that streaming arguments are returned incrementally, not cumulatively."""
|
||||||
|
# Reset streaming state
|
||||||
|
minimax_tool_parser.current_tool_name_sent = False
|
||||||
|
minimax_tool_parser.prev_tool_call_arr = []
|
||||||
|
minimax_tool_parser.current_tool_id = -1
|
||||||
|
minimax_tool_parser.streamed_args_for_tool = []
|
||||||
|
|
||||||
|
# Simulate progressive tool call building
|
||||||
|
stages = [
|
||||||
|
# Stage 1: Function name complete
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": ',
|
||||||
|
# Stage 2: Arguments object starts with first key
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": ',
|
||||||
|
# Stage 3: First parameter value added
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle"',
|
||||||
|
# Stage 4: Second parameter added
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"',
|
||||||
|
# Stage 5: Third parameter added, arguments complete
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
|
||||||
|
# Stage 6: Tool calls closed
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool',
|
||||||
|
'<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>'
|
||||||
|
]
|
||||||
|
|
||||||
|
function_name_sent = False
|
||||||
|
previous_args_content = ""
|
||||||
|
|
||||||
|
for i, current_text in enumerate(stages):
|
||||||
|
previous_text = stages[i - 1] if i > 0 else ""
|
||||||
|
delta_text = current_text[len(previous_text
|
||||||
|
):] if i > 0 else current_text
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Stage {i}: Current text: {repr(current_text)}")
|
||||||
|
print(f"Stage {i}: Delta text: {repr(delta_text)}")
|
||||||
|
|
||||||
|
if result is not None and hasattr(result,
|
||||||
|
'tool_calls') and result.tool_calls:
|
||||||
|
tool_call = result.tool_calls[0]
|
||||||
|
|
||||||
|
# Check if function name is sent (should happen only once)
|
||||||
|
if tool_call.function and tool_call.function.name:
|
||||||
|
assert tool_call.function.name == "get_current_weather"
|
||||||
|
function_name_sent = True
|
||||||
|
print(
|
||||||
|
f"Stage {i}: Function name sent: {tool_call.function.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if arguments are sent incrementally
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
args_fragment = tool_call.function.arguments
|
||||||
|
print(
|
||||||
|
f"Stage {i}: Got arguments fragment: {repr(args_fragment)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For incremental output, each fragment should be new content only
|
||||||
|
# The fragment should not contain all previous content
|
||||||
|
if i >= 2 and previous_args_content: # After we start getting arguments
|
||||||
|
# The new fragment should not be identical to or contain all previous content
|
||||||
|
assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}"
|
||||||
|
|
||||||
|
# If this is truly incremental, the fragment should be relatively small
|
||||||
|
# compared to the complete arguments so far
|
||||||
|
if len(args_fragment) > len(previous_args_content):
|
||||||
|
print(
|
||||||
|
"Warning: Fragment seems cumulative rather than incremental"
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_args_content = args_fragment
|
||||||
|
|
||||||
|
# Verify function name was sent at least once
|
||||||
|
assert function_name_sent, "Function name should have been sent"
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_arguments_delta_only(minimax_tool_parser):
|
||||||
|
"""Test that each streaming call returns only the delta (new part) of arguments."""
|
||||||
|
# Reset streaming state
|
||||||
|
minimax_tool_parser.current_tool_name_sent = False
|
||||||
|
minimax_tool_parser.prev_tool_call_arr = []
|
||||||
|
minimax_tool_parser.current_tool_id = -1
|
||||||
|
minimax_tool_parser.streamed_args_for_tool = []
|
||||||
|
|
||||||
|
# Simulate two consecutive calls with growing arguments
|
||||||
|
call1_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}'
|
||||||
|
call2_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}'
|
||||||
|
|
||||||
|
print(f"Call 1 text: {repr(call1_text)}")
|
||||||
|
print(f"Call 2 text: {repr(call2_text)}")
|
||||||
|
|
||||||
|
# First call - should get the function name and initial arguments
|
||||||
|
result1 = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text=call1_text,
|
||||||
|
delta_text=call1_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Result 1: {result1}")
|
||||||
|
if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls:
|
||||||
|
for i, tc in enumerate(result1.tool_calls):
|
||||||
|
print(f" Tool call {i}: {tc}")
|
||||||
|
|
||||||
|
# Second call - should only get the delta (new part) of arguments
|
||||||
|
result2 = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=call1_text,
|
||||||
|
current_text=call2_text,
|
||||||
|
delta_text=', "param2": "value2"}',
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Result 2: {result2}")
|
||||||
|
if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls:
|
||||||
|
for i, tc in enumerate(result2.tool_calls):
|
||||||
|
print(f" Tool call {i}: {tc}")
|
||||||
|
|
||||||
|
# Verify the second call only returns the delta
|
||||||
|
if result2 is not None and hasattr(result2,
|
||||||
|
'tool_calls') and result2.tool_calls:
|
||||||
|
tool_call = result2.tool_calls[0]
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
args_delta = tool_call.function.arguments
|
||||||
|
print(f"Arguments delta from second call: {repr(args_delta)}")
|
||||||
|
|
||||||
|
# Should only contain the new part, not the full arguments
|
||||||
|
# The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"'
|
||||||
|
assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}"
|
||||||
|
|
||||||
|
# Should NOT contain the previous parameter data
|
||||||
|
assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}"
|
||||||
|
|
||||||
|
# The delta should be relatively short (incremental, not cumulative)
|
||||||
|
expected_max_length = len(
|
||||||
|
', "param2": "value2"}') + 10 # Some tolerance
|
||||||
|
assert len(
|
||||||
|
args_delta
|
||||||
|
) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}"
|
||||||
|
|
||||||
|
print("✓ Delta validation passed")
|
||||||
|
else:
|
||||||
|
print("No arguments in result2 tool call")
|
||||||
|
else:
|
||||||
|
print("No tool calls in result2 or result2 is None")
|
||||||
|
# This might be acceptable if no incremental update is needed
|
||||||
|
# But let's at least verify that result1 had some content
|
||||||
|
assert result1 is not None, "At least the first call should return something"
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_openai_compatibility(minimax_tool_parser):
|
||||||
|
"""Test that streaming behavior with buffering works correctly."""
|
||||||
|
# Reset streaming state
|
||||||
|
minimax_tool_parser.current_tool_name_sent = False
|
||||||
|
minimax_tool_parser.prev_tool_call_arr = []
|
||||||
|
minimax_tool_parser.current_tool_id = -1
|
||||||
|
minimax_tool_parser.streamed_args_for_tool = []
|
||||||
|
# Reset buffering state
|
||||||
|
minimax_tool_parser.pending_buffer = ""
|
||||||
|
minimax_tool_parser.in_thinking_tag = False
|
||||||
|
minimax_tool_parser.thinking_depth = 0
|
||||||
|
|
||||||
|
# Test scenario: simple buffering without complex tool call context
|
||||||
|
test_cases: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
'stage': 'Token: <',
|
||||||
|
'previous': '',
|
||||||
|
'current': '<',
|
||||||
|
'delta': '<',
|
||||||
|
'expected_content': None, # Should be buffered
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Token: tool_calls>',
|
||||||
|
'previous': '<',
|
||||||
|
'current': '<tool_calls>',
|
||||||
|
'delta': 'tool_calls>',
|
||||||
|
'expected_content': None, # Complete tag, should not output
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Regular content',
|
||||||
|
'previous': 'Hello',
|
||||||
|
'current': 'Hello world',
|
||||||
|
'delta': ' world',
|
||||||
|
'expected_content': ' world', # Normal content should pass through
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Content with end tag start',
|
||||||
|
'previous': 'Text',
|
||||||
|
'current': 'Text content</tool_',
|
||||||
|
'delta': ' content</tool_',
|
||||||
|
'expected_content':
|
||||||
|
' content', # Content part output, </tool_ buffered
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Complete end tag',
|
||||||
|
'previous': 'Text content</tool_',
|
||||||
|
'current': 'Text content</tool_calls>',
|
||||||
|
'delta': 'calls>',
|
||||||
|
'expected_content': None, # Complete close tag, should not output
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, test_case in enumerate(test_cases):
|
||||||
|
print(f"\n--- Stage {i}: {test_case['stage']} ---")
|
||||||
|
print(f"Previous: {repr(test_case['previous'])}")
|
||||||
|
print(f"Current: {repr(test_case['current'])}")
|
||||||
|
print(f"Delta: {repr(test_case['delta'])}")
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=test_case['previous'],
|
||||||
|
current_text=test_case['current'],
|
||||||
|
delta_text=test_case['delta'],
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Result: {result}")
|
||||||
|
|
||||||
|
# Check expected content
|
||||||
|
if test_case['expected_content'] is None:
|
||||||
|
assert result is None or not getattr(result, 'content', None), \
|
||||||
|
f"Stage {i}: Expected no content, got {result}"
|
||||||
|
print("✓ No content output as expected")
|
||||||
|
else:
|
||||||
|
assert result is not None and hasattr(result, 'content'), \
|
||||||
|
f"Stage {i}: Expected content, got {result}"
|
||||||
|
assert result.content == test_case['expected_content'], \
|
||||||
|
f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}"
|
||||||
|
print(f"✓ Content matches: {repr(result.content)}")
|
||||||
|
|
||||||
|
print("✓ Streaming test with buffering completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_thinking_tag_buffering(minimax_tool_parser):
|
||||||
|
"""Test that tool calls within thinking tags are properly handled during streaming."""
|
||||||
|
# Reset streaming state
|
||||||
|
minimax_tool_parser.current_tool_name_sent = False
|
||||||
|
minimax_tool_parser.prev_tool_call_arr = []
|
||||||
|
minimax_tool_parser.current_tool_id = -1
|
||||||
|
minimax_tool_parser.streamed_args_for_tool = []
|
||||||
|
# Reset buffering state
|
||||||
|
minimax_tool_parser.pending_buffer = ""
|
||||||
|
minimax_tool_parser.in_thinking_tag = False
|
||||||
|
minimax_tool_parser.thinking_depth = 0
|
||||||
|
|
||||||
|
# Test scenario: tool calls within thinking tags should be ignored
|
||||||
|
test_cases: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
'stage': 'Start thinking',
|
||||||
|
'previous': '',
|
||||||
|
'current': '<think>I need to use a tool. <tool_calls>',
|
||||||
|
'delta': '<think>I need to use a tool. <tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'<think>I need to use a tool. <tool_calls>', # Should pass through as content
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage':
|
||||||
|
'Tool call in thinking',
|
||||||
|
'previous':
|
||||||
|
'<think>I need to use a tool. <tool_calls>',
|
||||||
|
'current':
|
||||||
|
'<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>',
|
||||||
|
'delta':
|
||||||
|
'\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Real tool call after thinking',
|
||||||
|
'previous':
|
||||||
|
'<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>',
|
||||||
|
'current':
|
||||||
|
'<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>',
|
||||||
|
'delta': '\n<tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'\n', # Should output '\n' and suppress <tool_calls>
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, test_case in enumerate(test_cases):
|
||||||
|
print(f"\n--- Stage {i}: {test_case['stage']} ---")
|
||||||
|
print(f"Previous: {repr(test_case['previous'])}")
|
||||||
|
print(f"Current: {repr(test_case['current'])}")
|
||||||
|
print(f"Delta: {repr(test_case['delta'])}")
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=test_case['previous'],
|
||||||
|
current_text=test_case['current'],
|
||||||
|
delta_text=test_case['delta'],
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Result: {result}")
|
||||||
|
|
||||||
|
# Check expected content
|
||||||
|
if 'expected_content' in test_case:
|
||||||
|
if test_case['expected_content'] is None:
|
||||||
|
assert result is None or not getattr(result, 'content', None), \
|
||||||
|
f"Stage {i}: Expected no content, got {result}"
|
||||||
|
else:
|
||||||
|
assert result is not None and hasattr(result, 'content'), \
|
||||||
|
f"Stage {i}: Expected content, got {result}"
|
||||||
|
assert result.content == test_case['expected_content'], \
|
||||||
|
f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}"
|
||||||
|
print(f"✓ Content matches: {repr(result.content)}")
|
||||||
|
|
||||||
|
# Check tool calls
|
||||||
|
if test_case.get('expected_tool_call'):
|
||||||
|
assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \
|
||||||
|
f"Stage {i}: Expected tool call, got {result}"
|
||||||
|
|
||||||
|
tool_call = result.tool_calls[0]
|
||||||
|
assert tool_call.function.name == "real_tool", \
|
||||||
|
f"Expected real_tool, got {tool_call.function.name}"
|
||||||
|
print(f"✓ Real tool call detected: {tool_call.function.name}")
|
||||||
|
|
||||||
|
print("✓ Thinking tag buffering test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def reset_streaming_state(minimax_tool_parser):
|
||||||
|
"""Helper function to properly reset the streaming state for MinimaxToolParser."""
|
||||||
|
# Reset minimax-specific state
|
||||||
|
minimax_tool_parser._reset_streaming_state()
|
||||||
|
|
||||||
|
# Reset base class state (these should still be reset for compatibility)
|
||||||
|
minimax_tool_parser.prev_tool_call_arr = []
|
||||||
|
minimax_tool_parser.current_tool_id = -1
|
||||||
|
minimax_tool_parser.current_tool_name_sent = False
|
||||||
|
minimax_tool_parser.streamed_args_for_tool = []
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser):
|
||||||
|
"""Test complex streaming scenario: tools inside <think> tags and multiple tool calls in one group."""
|
||||||
|
# Reset streaming state
|
||||||
|
reset_streaming_state(minimax_tool_parser)
|
||||||
|
|
||||||
|
# Complex scenario: tools inside thinking tags and multiple tools in one group
|
||||||
|
test_stages: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
'stage': 'Initial content',
|
||||||
|
'previous': '',
|
||||||
|
'current': 'Let me help you with this task.',
|
||||||
|
'delta': 'Let me help you with this task.',
|
||||||
|
'expected_content': 'Let me help you with this task.',
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Start thinking tag',
|
||||||
|
'previous': 'Let me help you with this task.',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.',
|
||||||
|
'delta': '<think>I need to analyze this situation first.',
|
||||||
|
'expected_content':
|
||||||
|
'<think>I need to analyze this situation first.',
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Tool call inside thinking tag starts',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>',
|
||||||
|
'delta': '<tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'<tool_calls>', # Inside thinking tags, tool tags should be preserved as content
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Complete tool call inside thinking tag',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
|
||||||
|
'delta':
|
||||||
|
'\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
|
||||||
|
'expected_tool_calls':
|
||||||
|
0, # Tools inside thinking tags should be ignored
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'End thinking tag',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>',
|
||||||
|
'delta': '</think>',
|
||||||
|
'expected_content': '</think>',
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Multiple tools group starts',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>',
|
||||||
|
'delta':
|
||||||
|
'\nNow I need to get weather information and calculate area.<tool_calls>',
|
||||||
|
'expected_content':
|
||||||
|
'\nNow I need to get weather information and calculate area.', # <tool_calls> should be filtered
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'First tool in group',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
|
||||||
|
'delta':
|
||||||
|
'\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
|
||||||
|
'expected_content':
|
||||||
|
None, # No content should be output when tool call is in progress
|
||||||
|
'expected_tool_calls': 1,
|
||||||
|
'expected_tool_name': 'get_current_weather',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Second tool in group',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
|
||||||
|
'delta':
|
||||||
|
'\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
|
||||||
|
'expected_content': None,
|
||||||
|
'expected_tool_calls': 1,
|
||||||
|
'expected_tool_name': 'calculate_area',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stage': 'Complete tool calls group',
|
||||||
|
'previous':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
|
||||||
|
'current':
|
||||||
|
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>',
|
||||||
|
'delta': '</tool_calls>',
|
||||||
|
'expected_content': None,
|
||||||
|
'expected_tool_calls': 0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
|
||||||
|
for i, test_case in enumerate(test_stages):
|
||||||
|
print(f"\n--- Stage {i}: {test_case['stage']} ---")
|
||||||
|
print(
|
||||||
|
f"Previous: {repr(test_case['previous'][:100])}{'...' if len(test_case['previous']) > 100 else ''}"
|
||||||
|
)
|
||||||
|
print(f"Current: {repr(test_case['current'][-100:])}")
|
||||||
|
print(f"Delta: {repr(test_case['delta'])}")
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=test_case['previous'],
|
||||||
|
current_text=test_case['current'],
|
||||||
|
delta_text=test_case['delta'],
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Result: {result}")
|
||||||
|
|
||||||
|
# Check expected content
|
||||||
|
if test_case['expected_content'] is None:
|
||||||
|
assert result is None or not getattr(result, 'content', None), \
|
||||||
|
f"Stage {i}: Expected no content output, got {result}"
|
||||||
|
print("✓ No content output as expected")
|
||||||
|
else:
|
||||||
|
assert result is not None and hasattr(result, 'content'), \
|
||||||
|
f"Stage {i}: Expected content output, got {result}"
|
||||||
|
assert result.content == test_case['expected_content'], \
|
||||||
|
f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}"
|
||||||
|
print(f"✓ Content matches: {repr(result.content)}")
|
||||||
|
|
||||||
|
# Check tool calls
|
||||||
|
expected_tool_calls = test_case['expected_tool_calls']
|
||||||
|
actual_tool_calls = len(result.tool_calls) if result and hasattr(
|
||||||
|
result, 'tool_calls') and result.tool_calls else 0
|
||||||
|
|
||||||
|
if expected_tool_calls > 0:
|
||||||
|
assert actual_tool_calls >= expected_tool_calls, \
|
||||||
|
f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}"
|
||||||
|
|
||||||
|
if 'expected_tool_name' in test_case:
|
||||||
|
# Find the tool call with the expected name
|
||||||
|
found_tool_call = None
|
||||||
|
for tool_call in result.tool_calls:
|
||||||
|
if tool_call.function.name == test_case[
|
||||||
|
'expected_tool_name']:
|
||||||
|
found_tool_call = tool_call
|
||||||
|
break
|
||||||
|
|
||||||
|
assert found_tool_call is not None, \
|
||||||
|
f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}"
|
||||||
|
print(f"✓ Tool call correct: {found_tool_call.function.name}")
|
||||||
|
|
||||||
|
# Ensure tools inside thinking tags are not called
|
||||||
|
assert found_tool_call.function.name != "internal_analysis", \
|
||||||
|
f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called"
|
||||||
|
|
||||||
|
tool_calls_count += actual_tool_calls
|
||||||
|
print(f"✓ Detected {actual_tool_calls} tool calls")
|
||||||
|
else:
|
||||||
|
assert actual_tool_calls == 0, \
|
||||||
|
f"Stage {i}: Expected no tool calls, got {actual_tool_calls}"
|
||||||
|
|
||||||
|
# Verify overall results
|
||||||
|
print("\n=== Test Summary ===")
|
||||||
|
print(f"Total tool calls count: {tool_calls_count}")
|
||||||
|
assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}"
|
||||||
|
|
||||||
|
print("✓ Complex streaming test completed:")
|
||||||
|
print(" - ✓ Tools inside thinking tags correctly ignored")
|
||||||
|
print(" - ✓ Two tool groups outside thinking tags correctly parsed")
|
||||||
|
print(" - ✓ Content and tool call streaming correctly handled")
|
||||||
|
print(" - ✓ Buffering mechanism works correctly")
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_character_by_character_output(minimax_tool_parser):
|
||||||
|
"""Test character-by-character streaming output to simulate real streaming scenarios."""
|
||||||
|
# Reset streaming state
|
||||||
|
reset_streaming_state(minimax_tool_parser)
|
||||||
|
|
||||||
|
# Complete text that will be streamed character by character
|
||||||
|
complete_text = """I'll help you with the weather analysis. <think>Let me think about this. <tool_calls>
|
||||||
|
{"name": "internal_analysis", "arguments": {"type": "thinking"}}
|
||||||
|
</tool_calls>This tool should be ignored.</think>
|
||||||
|
|
||||||
|
Now I'll get the weather information for you. <tool_calls>
|
||||||
|
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
|
||||||
|
{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}
|
||||||
|
</tool_calls>Here are the results."""
|
||||||
|
|
||||||
|
print("\n=== Starting character-by-character streaming test ===")
|
||||||
|
print(f"Complete text length: {len(complete_text)} characters")
|
||||||
|
|
||||||
|
# Track the streaming results
|
||||||
|
content_fragments = []
|
||||||
|
tool_calls_detected = []
|
||||||
|
|
||||||
|
# Stream character by character
|
||||||
|
for i in range(1, len(complete_text) + 1):
|
||||||
|
current_text = complete_text[:i]
|
||||||
|
previous_text = complete_text[:i - 1] if i > 1 else ""
|
||||||
|
delta_text = complete_text[i - 1:i]
|
||||||
|
|
||||||
|
# Show progress every 50 characters
|
||||||
|
if i % 50 == 0 or i == len(complete_text):
|
||||||
|
print(f"Progress: {i}/{len(complete_text)} characters")
|
||||||
|
|
||||||
|
# Call the streaming parser
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect results
|
||||||
|
if result is not None:
|
||||||
|
if hasattr(result, 'content') and result.content:
|
||||||
|
content_fragments.append(result.content)
|
||||||
|
# Log important content fragments
|
||||||
|
if any(
|
||||||
|
keyword in result.content for keyword in
|
||||||
|
['<think>', '</think>', '<tool_calls>', '</tool_calls>']):
|
||||||
|
print(
|
||||||
|
f" Char {i}: Content fragment: {repr(result.content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(result, 'tool_calls') and result.tool_calls:
|
||||||
|
for tool_call in result.tool_calls:
|
||||||
|
tool_info = {
|
||||||
|
'character_position':
|
||||||
|
i,
|
||||||
|
'function_name':
|
||||||
|
tool_call.function.name
|
||||||
|
if tool_call.function else None,
|
||||||
|
'arguments':
|
||||||
|
tool_call.function.arguments
|
||||||
|
if tool_call.function else None,
|
||||||
|
}
|
||||||
|
tool_calls_detected.append(tool_info)
|
||||||
|
print(
|
||||||
|
f" Char {i}: Tool call detected: {tool_call.function.name}"
|
||||||
|
)
|
||||||
|
if tool_call.function.arguments:
|
||||||
|
print(
|
||||||
|
f" Arguments: {repr(tool_call.function.arguments)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
print("\n=== Streaming Test Results ===")
|
||||||
|
print(f"Total content fragments: {len(content_fragments)}")
|
||||||
|
print(f"Total tool calls detected: {len(tool_calls_detected)}")
|
||||||
|
|
||||||
|
# Reconstruct content from fragments
|
||||||
|
reconstructed_content = ''.join(content_fragments)
|
||||||
|
print(f"Reconstructed content length: {len(reconstructed_content)}")
|
||||||
|
|
||||||
|
# Verify thinking tags content is preserved
|
||||||
|
assert '<think>' in reconstructed_content, "Opening thinking tag should be preserved in content"
|
||||||
|
assert '</think>' in reconstructed_content, "Closing thinking tag should be preserved in content"
|
||||||
|
|
||||||
|
# Verify that tool calls inside thinking tags are NOT extracted as actual tool calls
|
||||||
|
thinking_tool_calls = [
|
||||||
|
tc for tc in tool_calls_detected
|
||||||
|
if tc['function_name'] == 'internal_analysis'
|
||||||
|
]
|
||||||
|
assert len(
|
||||||
|
thinking_tool_calls
|
||||||
|
) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}"
|
||||||
|
|
||||||
|
# Verify that real tool calls outside thinking tags ARE extracted
|
||||||
|
weather_tool_calls = [
|
||||||
|
tc for tc in tool_calls_detected
|
||||||
|
if tc['function_name'] == 'get_current_weather'
|
||||||
|
]
|
||||||
|
area_tool_calls = [
|
||||||
|
tc for tc in tool_calls_detected
|
||||||
|
if tc['function_name'] == 'calculate_area'
|
||||||
|
]
|
||||||
|
print(tool_calls_detected)
|
||||||
|
assert len(weather_tool_calls
|
||||||
|
) > 0, "get_current_weather tool call should be detected"
|
||||||
|
assert len(
|
||||||
|
area_tool_calls) > 0, "calculate_area tool call should be detected"
|
||||||
|
|
||||||
|
# Verify tool call arguments are properly streamed
|
||||||
|
weather_args_found = any(tc['arguments'] for tc in weather_tool_calls
|
||||||
|
if tc['arguments'])
|
||||||
|
area_args_found = any(tc['arguments'] for tc in area_tool_calls
|
||||||
|
if tc['arguments'])
|
||||||
|
|
||||||
|
print(f"Weather tool call with arguments: {weather_args_found}")
|
||||||
|
print(f"Area tool call with arguments: {area_args_found}")
|
||||||
|
|
||||||
|
# Verify content before and after tool calls
|
||||||
|
assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved"
|
||||||
|
assert 'Here are the results.' in reconstructed_content, "Final content should be preserved"
|
||||||
|
|
||||||
|
# Verify that <tool_calls> and </tool_calls> tags are not included in the final content
|
||||||
|
# (they should be filtered out when not inside thinking tags)
|
||||||
|
content_outside_thinking = reconstructed_content
|
||||||
|
# Remove thinking tag content to check content outside
|
||||||
|
if '<think>' in content_outside_thinking and '</think>' in content_outside_thinking:
|
||||||
|
start_think = content_outside_thinking.find('<think>')
|
||||||
|
end_think = content_outside_thinking.find('</think>') + len('</think>')
|
||||||
|
content_outside_thinking = content_outside_thinking[:
|
||||||
|
start_think] + content_outside_thinking[
|
||||||
|
end_think:]
|
||||||
|
|
||||||
|
# Outside thinking tags, tool_calls tags should be filtered
|
||||||
|
tool_calls_in_content = content_outside_thinking.count('<tool_calls>')
|
||||||
|
assert tool_calls_in_content == 0, f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}"
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\n=== Character-by-character streaming test completed successfully ==="
|
||||||
|
)
|
||||||
|
print("✓ Tool calls inside thinking tags correctly ignored")
|
||||||
|
print("✓ Tool calls outside thinking tags correctly detected")
|
||||||
|
print("✓ Content properly streamed and reconstructed")
|
||||||
|
print("✓ Tool call tags properly filtered from content")
|
||||||
|
print("✓ Character-level streaming works correctly")
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_character_by_character_simple_tool_call(
|
||||||
|
minimax_tool_parser):
|
||||||
|
"""Test character-by-character streaming for a simple tool call scenario."""
|
||||||
|
# Reset streaming state
|
||||||
|
reset_streaming_state(minimax_tool_parser)
|
||||||
|
|
||||||
|
# Simple tool call text
|
||||||
|
simple_text = 'Let me check the weather. <tool_calls>\n{"name": "get_weather", "arguments": {"city": "NYC"}}\n</tool_calls>'
|
||||||
|
|
||||||
|
print("\n=== Simple character-by-character test ===")
|
||||||
|
print(f"Text: {repr(simple_text)}")
|
||||||
|
|
||||||
|
content_parts = []
|
||||||
|
tool_name_sent = False
|
||||||
|
tool_args_sent = False
|
||||||
|
|
||||||
|
for i in range(1, len(simple_text) + 1):
|
||||||
|
current_text = simple_text[:i]
|
||||||
|
previous_text = simple_text[:i - 1] if i > 1 else ""
|
||||||
|
delta_text = simple_text[i - 1:i]
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
if hasattr(result, 'content') and result.content:
|
||||||
|
content_parts.append(result.content)
|
||||||
|
print(
|
||||||
|
f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(result, 'tool_calls') and result.tool_calls:
|
||||||
|
for tool_call in result.tool_calls:
|
||||||
|
if tool_call.function and tool_call.function.name:
|
||||||
|
tool_name_sent = True
|
||||||
|
print(
|
||||||
|
f" Char {i}: Tool name: {tool_call.function.name}"
|
||||||
|
)
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
tool_args_sent = True
|
||||||
|
print(
|
||||||
|
f" Char {i}: Tool args: {repr(tool_call.function.arguments)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify basic expectations
|
||||||
|
reconstructed_content = ''.join(content_parts)
|
||||||
|
print(f"Final reconstructed content: {repr(reconstructed_content)}")
|
||||||
|
|
||||||
|
assert tool_name_sent, "Tool name should be sent during streaming"
|
||||||
|
assert tool_args_sent, "Tool arguments should be sent during streaming"
|
||||||
|
assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved"
|
||||||
|
|
||||||
|
print("✓ Simple character-by-character test passed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_character_by_character_with_buffering(minimax_tool_parser):
|
||||||
|
"""Test character-by-character streaming with edge cases that trigger buffering."""
|
||||||
|
# Reset streaming state
|
||||||
|
reset_streaming_state(minimax_tool_parser)
|
||||||
|
|
||||||
|
# Text that includes potential buffering scenarios
|
||||||
|
buffering_text = 'Hello world<tool_calls>\n{"name": "test"}\n</tool_calls>done'
|
||||||
|
|
||||||
|
print("\n=== Buffering character-by-character test ===")
|
||||||
|
print(f"Text: {repr(buffering_text)}")
|
||||||
|
|
||||||
|
all_content = []
|
||||||
|
|
||||||
|
for i in range(1, len(buffering_text) + 1):
|
||||||
|
current_text = buffering_text[:i]
|
||||||
|
previous_text = buffering_text[:i - 1] if i > 1 else ""
|
||||||
|
delta_text = buffering_text[i - 1:i]
|
||||||
|
|
||||||
|
result = minimax_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result and hasattr(result, 'content') and result.content:
|
||||||
|
all_content.append(result.content)
|
||||||
|
print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}")
|
||||||
|
|
||||||
|
final_content = ''.join(all_content)
|
||||||
|
print(f"Final content: {repr(final_content)}")
|
||||||
|
|
||||||
|
# The parser should handle the edge case where </tool_calls> appears before <tool_calls>
|
||||||
|
assert "Hello" in final_content, "Initial 'Hello' should be preserved"
|
||||||
|
assert "world" in final_content, "Content after false closing tag should be preserved"
|
||||||
|
assert "done" in final_content, "Final content should be preserved"
|
||||||
|
|
||||||
|
print("✓ Buffering character-by-character test passed")
|
||||||
|
|||||||
@ -3,11 +3,9 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import partial_json_parser
|
|
||||||
import regex as re
|
import regex as re
|
||||||
from partial_json_parser.core.options import Allow
|
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
@ -17,6 +15,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
FunctionCall, ToolCall)
|
FunctionCall, ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
ToolParser, ToolParserManager)
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||||
|
extract_intermediate_diff)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
@ -29,25 +29,32 @@ class MinimaxToolParser(ToolParser):
|
|||||||
def __init__(self, tokenizer: AnyTokenizer):
|
def __init__(self, tokenizer: AnyTokenizer):
|
||||||
super().__init__(tokenizer)
|
super().__init__(tokenizer)
|
||||||
|
|
||||||
self.current_tool_name_sent: bool = False
|
# Initialize streaming state for tracking tool call progress
|
||||||
self.prev_tool_call_arr: list[dict] = []
|
self.streaming_state: dict[str, Any] = {
|
||||||
self.current_tool_id: int = -1
|
"current_tool_index": -1, # Index of current tool being processed
|
||||||
self.streamed_args_for_tool: list[str] = []
|
"tool_ids": [], # List of tool call IDs
|
||||||
|
"sent_tools": [], # List of tools that have been sent
|
||||||
self.tool_call_start_token: str = "<tool_calls>"
|
}
|
||||||
self.tool_call_end_token: str = "</tool_calls>"
|
|
||||||
|
|
||||||
|
# Define tool call tokens and patterns
|
||||||
|
self.tool_call_start_token = "<tool_calls>"
|
||||||
|
self.tool_call_end_token = "</tool_calls>"
|
||||||
self.tool_call_regex = re.compile(
|
self.tool_call_regex = re.compile(
|
||||||
r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
|
r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
|
||||||
|
|
||||||
# Add regex pattern for thinking tag
|
|
||||||
self.thinking_tag_pattern = r"<think>(.*?)</think>"
|
self.thinking_tag_pattern = r"<think>(.*?)</think>"
|
||||||
|
self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
|
||||||
|
self.tool_args_pattern = re.compile(r'"arguments":\s*')
|
||||||
|
|
||||||
|
# Buffer for handling partial tool calls during streaming
|
||||||
|
self.pending_buffer = ""
|
||||||
|
self.in_thinking_tag = False
|
||||||
|
|
||||||
if not self.model_tokenizer:
|
if not self.model_tokenizer:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The model tokenizer must be passed to the ToolParser "
|
"The model tokenizer must be passed to the ToolParser "
|
||||||
"constructor during construction.")
|
"constructor during construction.")
|
||||||
|
|
||||||
|
# Get token IDs for tool call start/end tokens
|
||||||
self.tool_call_start_token_id = self.vocab.get(
|
self.tool_call_start_token_id = self.vocab.get(
|
||||||
self.tool_call_start_token)
|
self.tool_call_start_token)
|
||||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||||
@ -60,33 +67,95 @@ class MinimaxToolParser(ToolParser):
|
|||||||
|
|
||||||
def preprocess_model_output(self, model_output: str) -> str:
|
def preprocess_model_output(self, model_output: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove tool calls from within thinking tags to avoid processing them.
|
Preprocess model output by removing tool calls from thinking tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_output: Raw model output string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed model output with tool calls removed from thinking tags
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def remove_tool_calls_from_think(match):
|
def remove_tool_calls_from_think(match):
|
||||||
think_content = match.group(1)
|
think_content = match.group(1)
|
||||||
# Remove tool_calls from within the think tag
|
|
||||||
cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
|
cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
|
||||||
"",
|
"",
|
||||||
think_content,
|
think_content,
|
||||||
flags=re.DOTALL)
|
flags=re.DOTALL)
|
||||||
return f"<think>{cleaned_content}</think>"
|
return f"<think>{cleaned_content}</think>"
|
||||||
|
|
||||||
# Process thinking tags and remove tool_calls from within them
|
return re.sub(self.thinking_tag_pattern,
|
||||||
processed_output = re.sub(self.thinking_tag_pattern,
|
remove_tool_calls_from_think,
|
||||||
remove_tool_calls_from_think,
|
model_output,
|
||||||
model_output,
|
flags=re.DOTALL)
|
||||||
flags=re.DOTALL)
|
|
||||||
|
|
||||||
return processed_output
|
def _clean_duplicate_braces(self, args_text: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean duplicate closing braces from arguments text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_text: Raw arguments text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned arguments text with proper JSON formatting
|
||||||
|
"""
|
||||||
|
args_text = args_text.strip()
|
||||||
|
if not args_text:
|
||||||
|
return args_text
|
||||||
|
|
||||||
|
try:
|
||||||
|
json.loads(args_text)
|
||||||
|
return args_text
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
while args_text.endswith('}}'):
|
||||||
|
candidate = args_text[:-1]
|
||||||
|
try:
|
||||||
|
json.loads(candidate)
|
||||||
|
return candidate
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args_text = candidate
|
||||||
|
|
||||||
|
return args_text
|
||||||
|
|
||||||
|
def _clean_delta_braces(self, delta_text: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean delta text by removing excessive closing braces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta_text: Delta text to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned delta text
|
||||||
|
"""
|
||||||
|
if not delta_text:
|
||||||
|
return delta_text
|
||||||
|
|
||||||
|
delta_stripped = delta_text.strip()
|
||||||
|
|
||||||
|
if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped):
|
||||||
|
brace_count = delta_stripped.count('}')
|
||||||
|
if brace_count > 1:
|
||||||
|
return '}\n' if delta_text.endswith('\n') else '}'
|
||||||
|
|
||||||
|
return delta_text
|
||||||
|
|
||||||
def extract_tool_calls(
|
def extract_tool_calls(
|
||||||
self,
|
self,
|
||||||
model_output: str,
|
model_output: str,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> ExtractedToolCallInformation:
|
) -> ExtractedToolCallInformation:
|
||||||
|
"""
|
||||||
|
Extract tool calls from model output for non-streaming mode.
|
||||||
|
|
||||||
# Preprocess to remove tool calls from thinking tags
|
Args:
|
||||||
|
model_output: Complete model output
|
||||||
|
request: Chat completion request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedToolCallInformation containing tool calls and content
|
||||||
|
"""
|
||||||
processed_output = self.preprocess_model_output(model_output)
|
processed_output = self.preprocess_model_output(model_output)
|
||||||
|
|
||||||
if self.tool_call_start_token not in processed_output:
|
if self.tool_call_start_token not in processed_output:
|
||||||
@ -95,8 +164,8 @@ class MinimaxToolParser(ToolParser):
|
|||||||
content=model_output)
|
content=model_output)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
function_call_tuples = (
|
function_call_tuples = self.tool_call_regex.findall(
|
||||||
self.tool_call_regex.findall(processed_output))
|
processed_output)
|
||||||
|
|
||||||
raw_function_calls = []
|
raw_function_calls = []
|
||||||
for match in function_call_tuples:
|
for match in function_call_tuples:
|
||||||
@ -124,21 +193,15 @@ class MinimaxToolParser(ToolParser):
|
|||||||
function_call["arguments"],
|
function_call["arguments"],
|
||||||
ensure_ascii=False))))
|
ensure_ascii=False))))
|
||||||
|
|
||||||
# Extract content before the first valid tool call
|
|
||||||
# Find the position in processed output, then map back to original
|
|
||||||
processed_pos = processed_output.find(self.tool_call_start_token)
|
processed_pos = processed_output.find(self.tool_call_start_token)
|
||||||
if processed_pos != -1:
|
if processed_pos != -1:
|
||||||
# Get the content before tool calls in processed output
|
|
||||||
processed_content = processed_output[:processed_pos].strip()
|
processed_content = processed_output[:processed_pos].strip()
|
||||||
|
|
||||||
if processed_content:
|
if processed_content:
|
||||||
# Find the end of this content in the original output
|
|
||||||
# Look for the last non-empty line of processed content
|
|
||||||
lines = processed_content.split('\n')
|
lines = processed_content.split('\n')
|
||||||
for line in reversed(lines):
|
for line in reversed(lines):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line:
|
if line:
|
||||||
# Find this line in original output
|
|
||||||
pos = model_output.find(line)
|
pos = model_output.find(line)
|
||||||
if pos != -1:
|
if pos != -1:
|
||||||
content = model_output[:pos + len(line)]
|
content = model_output[:pos + len(line)]
|
||||||
@ -162,6 +225,445 @@ class MinimaxToolParser(ToolParser):
|
|||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
content=model_output)
|
content=model_output)
|
||||||
|
|
||||||
|
def _update_thinking_state(self, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Update the thinking tag state based on text content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to analyze for thinking tags
|
||||||
|
"""
|
||||||
|
open_count = text.count("<think>")
|
||||||
|
close_count = text.count("</think>")
|
||||||
|
self.in_thinking_tag = open_count > close_count or (
|
||||||
|
open_count == close_count and text.endswith("</think>"))
|
||||||
|
|
||||||
|
def _is_potential_tag_start(self, text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if text might be the start of a tool call tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if text could be the start of a tool call tag
|
||||||
|
"""
|
||||||
|
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
|
||||||
|
if any(
|
||||||
|
tag.startswith(text[-i:])
|
||||||
|
for i in range(1, min(len(text) + 1, len(tag)))):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _should_buffer_content(self, delta_text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if content should be buffered for later processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta_text: Delta text to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if content should be buffered
|
||||||
|
"""
|
||||||
|
if self.in_thinking_tag:
|
||||||
|
return False
|
||||||
|
return bool(self.pending_buffer
|
||||||
|
or self.tool_call_start_token in delta_text
|
||||||
|
or self.tool_call_end_token in delta_text
|
||||||
|
or delta_text.startswith('<'))
|
||||||
|
|
||||||
|
def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Split delta text into safe content and potential tag content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta_text: Delta text to split
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (safe_content, potential_tag_content)
|
||||||
|
"""
|
||||||
|
if self.in_thinking_tag:
|
||||||
|
return delta_text, ""
|
||||||
|
|
||||||
|
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
|
||||||
|
for i in range(1, len(tag)):
|
||||||
|
tag_prefix = tag[:i]
|
||||||
|
pos = delta_text.rfind(tag_prefix)
|
||||||
|
if pos != -1 and tag.startswith(delta_text[pos:]):
|
||||||
|
return delta_text[:pos], delta_text[pos:]
|
||||||
|
return delta_text, ""
|
||||||
|
|
||||||
|
def _process_buffer(self, new_content: str) -> str:
|
||||||
|
"""
|
||||||
|
Process buffered content and return output content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_content: New content to add to buffer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output content
|
||||||
|
"""
|
||||||
|
self.pending_buffer += new_content
|
||||||
|
output_content = ""
|
||||||
|
|
||||||
|
if self.in_thinking_tag:
|
||||||
|
output_content = self.pending_buffer
|
||||||
|
self.pending_buffer = ""
|
||||||
|
return output_content
|
||||||
|
|
||||||
|
while self.pending_buffer:
|
||||||
|
start_pos = self.pending_buffer.find(self.tool_call_start_token)
|
||||||
|
end_pos = self.pending_buffer.find(self.tool_call_end_token)
|
||||||
|
|
||||||
|
if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
|
||||||
|
tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
|
||||||
|
elif end_pos != -1:
|
||||||
|
tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
|
||||||
|
else:
|
||||||
|
if self._is_potential_tag_start(self.pending_buffer):
|
||||||
|
break
|
||||||
|
output_content += self.pending_buffer
|
||||||
|
self.pending_buffer = ""
|
||||||
|
break
|
||||||
|
|
||||||
|
output_content += self.pending_buffer[:tag_pos]
|
||||||
|
self.pending_buffer = self.pending_buffer[tag_pos + tag_len:]
|
||||||
|
|
||||||
|
return output_content
|
||||||
|
|
||||||
|
def _reset_streaming_state(self) -> None:
|
||||||
|
"""Reset the streaming state to initial values."""
|
||||||
|
self.streaming_state = {
|
||||||
|
"current_tool_index": -1,
|
||||||
|
"tool_ids": [],
|
||||||
|
"sent_tools": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _advance_to_next_tool(self) -> None:
|
||||||
|
"""Advance to the next tool in the streaming sequence."""
|
||||||
|
self.streaming_state["current_tool_index"] = int(
|
||||||
|
self.streaming_state["current_tool_index"]) + 1
|
||||||
|
|
||||||
|
def _set_current_tool_index(self, index: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the current tool index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: Tool index to set
|
||||||
|
"""
|
||||||
|
self.streaming_state["current_tool_index"] = index
|
||||||
|
|
||||||
|
def _get_current_tool_index(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the current tool index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current tool index
|
||||||
|
"""
|
||||||
|
return int(self.streaming_state["current_tool_index"])
|
||||||
|
|
||||||
|
def _get_next_unsent_tool_index(self, tool_count: int) -> int:
|
||||||
|
"""
|
||||||
|
Get the index of the next unsent tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_count: Total number of tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Index of next unsent tool, or -1 if all tools sent
|
||||||
|
"""
|
||||||
|
sent_tools = list(self.streaming_state["sent_tools"])
|
||||||
|
for i in range(tool_count):
|
||||||
|
if i < len(sent_tools):
|
||||||
|
if not sent_tools[i]["sent_name"]:
|
||||||
|
return i
|
||||||
|
else:
|
||||||
|
return i
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def _ensure_state_arrays(self, tool_count: int) -> None:
|
||||||
|
"""
|
||||||
|
Ensure state arrays have sufficient capacity for tool_count tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_count: Number of tools to prepare for
|
||||||
|
"""
|
||||||
|
sent_tools = list(self.streaming_state["sent_tools"])
|
||||||
|
tool_ids = list(self.streaming_state["tool_ids"])
|
||||||
|
|
||||||
|
while len(sent_tools) < tool_count:
|
||||||
|
sent_tools.append({
|
||||||
|
"sent_name": False,
|
||||||
|
"sent_arguments": "",
|
||||||
|
"id": random_tool_call_id(),
|
||||||
|
})
|
||||||
|
|
||||||
|
while len(tool_ids) < tool_count:
|
||||||
|
tool_ids.append(None)
|
||||||
|
|
||||||
|
self.streaming_state["sent_tools"] = sent_tools
|
||||||
|
self.streaming_state["tool_ids"] = tool_ids
|
||||||
|
|
||||||
|
def _detect_tools_in_text(self, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Detect the number of tools in text by counting name patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tools detected
|
||||||
|
"""
|
||||||
|
matches = self.tool_name_pattern.findall(text)
|
||||||
|
return len(matches)
|
||||||
|
|
||||||
|
def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Find the boundaries of tool calls in text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (start, end) positions for tool calls
|
||||||
|
"""
|
||||||
|
boundaries = []
|
||||||
|
i = 0
|
||||||
|
while i < len(text):
|
||||||
|
if text[i] == '{':
|
||||||
|
start = i
|
||||||
|
depth = 0
|
||||||
|
has_name = False
|
||||||
|
has_arguments = False
|
||||||
|
|
||||||
|
while i < len(text):
|
||||||
|
if text[i] == '{':
|
||||||
|
depth += 1
|
||||||
|
elif text[i] == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
end = i + 1
|
||||||
|
segment = text[start:end]
|
||||||
|
if '"name"' in segment and '"arguments"' in segment:
|
||||||
|
boundaries.append((start, end))
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_name and '"name"' in text[start:i + 1]:
|
||||||
|
has_name = True
|
||||||
|
if not has_arguments and '"arguments"' in text[start:i +
|
||||||
|
1]:
|
||||||
|
has_arguments = True
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if depth > 0 and has_name:
|
||||||
|
boundaries.append((start, i))
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
return boundaries
|
||||||
|
|
||||||
|
def _extract_tool_args(self, tool_content: str, args_match) -> str:
|
||||||
|
"""
|
||||||
|
Extract tool arguments from tool content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_content: Tool call content
|
||||||
|
args_match: Regex match for arguments pattern
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Extracted arguments as string
|
||||||
|
"""
|
||||||
|
args_start_pos = args_match.end()
|
||||||
|
remaining_content = tool_content[args_start_pos:]
|
||||||
|
|
||||||
|
if remaining_content.strip().startswith('{'):
|
||||||
|
depth = 0
|
||||||
|
for i, char in enumerate(remaining_content):
|
||||||
|
if char == '{':
|
||||||
|
depth += 1
|
||||||
|
elif char == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return remaining_content[:i + 1]
|
||||||
|
else:
|
||||||
|
args_end = remaining_content.find('}')
|
||||||
|
if args_end > 0:
|
||||||
|
return remaining_content[:args_end].strip()
|
||||||
|
|
||||||
|
return remaining_content.rstrip('}').strip()
|
||||||
|
|
||||||
|
def _get_current_tool_content(
|
||||||
|
self, text: str,
|
||||||
|
tool_index: int) -> tuple[Optional[str], Optional[str]]:
|
||||||
|
"""
|
||||||
|
Get the content of a specific tool by index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text containing tool calls
|
||||||
|
tool_index: Index of tool to extract
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (tool_name, tool_arguments) or (None, None) if not found
|
||||||
|
"""
|
||||||
|
boundaries = self._find_tool_boundaries(text)
|
||||||
|
|
||||||
|
if tool_index >= len(boundaries):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
start, end = boundaries[tool_index]
|
||||||
|
tool_content = text[start:end]
|
||||||
|
|
||||||
|
name_match = self.tool_name_pattern.search(tool_content)
|
||||||
|
name = name_match.group(1) if name_match else None
|
||||||
|
|
||||||
|
args_match = self.tool_args_pattern.search(tool_content)
|
||||||
|
if args_match:
|
||||||
|
try:
|
||||||
|
args_text = self._extract_tool_args(tool_content, args_match)
|
||||||
|
return name, args_text
|
||||||
|
except Exception:
|
||||||
|
remaining_content = tool_content[args_match.end():]
|
||||||
|
args_text = remaining_content.rstrip('}').strip()
|
||||||
|
return name, args_text
|
||||||
|
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
def _handle_tool_name_streaming(
|
||||||
|
self, tool_content: str,
|
||||||
|
tool_count: int) -> Union[DeltaMessage, None]:
|
||||||
|
"""
|
||||||
|
Handle streaming of tool names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_content: Content containing tool calls
|
||||||
|
tool_count: Total number of tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DeltaMessage with tool name or None if no tool to stream
|
||||||
|
"""
|
||||||
|
next_idx = self._get_next_unsent_tool_index(tool_count)
|
||||||
|
|
||||||
|
if next_idx == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
boundaries = self._find_tool_boundaries(tool_content)
|
||||||
|
if next_idx >= len(boundaries):
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
|
||||||
|
if not tool_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._set_current_tool_index(next_idx)
|
||||||
|
sent_tools = list(self.streaming_state["sent_tools"])
|
||||||
|
tool_ids = list(self.streaming_state["tool_ids"])
|
||||||
|
|
||||||
|
tool_id = sent_tools[next_idx]["id"]
|
||||||
|
tool_ids[next_idx] = tool_id
|
||||||
|
sent_tools[next_idx]["sent_name"] = True
|
||||||
|
|
||||||
|
self.streaming_state["sent_tools"] = sent_tools
|
||||||
|
self.streaming_state["tool_ids"] = tool_ids
|
||||||
|
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(index=next_idx,
|
||||||
|
type="function",
|
||||||
|
id=tool_id,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=tool_name).model_dump(exclude_none=True))
|
||||||
|
])
|
||||||
|
|
||||||
|
def _handle_tool_args_streaming(
|
||||||
|
self, tool_content: str,
|
||||||
|
tool_count: int) -> Union[DeltaMessage, None]:
|
||||||
|
"""
|
||||||
|
Handle streaming of tool arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_content: Content containing tool calls
|
||||||
|
tool_count: Total number of tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DeltaMessage with tool arguments or None if no arguments to stream
|
||||||
|
"""
|
||||||
|
current_idx = self._get_current_tool_index()
|
||||||
|
|
||||||
|
if current_idx < 0 or current_idx >= tool_count:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_name, tool_args = self._get_current_tool_content(
|
||||||
|
tool_content, current_idx)
|
||||||
|
if not tool_name or tool_args is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sent_tools = list(self.streaming_state["sent_tools"])
|
||||||
|
|
||||||
|
if not sent_tools[current_idx]["sent_name"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
clean_args = self._clean_duplicate_braces(tool_args)
|
||||||
|
sent_args = sent_tools[current_idx]["sent_arguments"]
|
||||||
|
|
||||||
|
if clean_args != sent_args:
|
||||||
|
if sent_args and clean_args.startswith(sent_args):
|
||||||
|
args_delta = extract_intermediate_diff(clean_args, sent_args)
|
||||||
|
if args_delta:
|
||||||
|
args_delta = self._clean_delta_braces(args_delta)
|
||||||
|
sent_tools[current_idx]["sent_arguments"] = clean_args
|
||||||
|
self.streaming_state["sent_tools"] = sent_tools
|
||||||
|
|
||||||
|
if clean_args.endswith('}'):
|
||||||
|
self._advance_to_next_tool()
|
||||||
|
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(index=current_idx,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=args_delta).model_dump(
|
||||||
|
exclude_none=True))
|
||||||
|
])
|
||||||
|
elif not sent_args and clean_args:
|
||||||
|
clean_args_delta = self._clean_delta_braces(clean_args)
|
||||||
|
sent_tools[current_idx]["sent_arguments"] = clean_args
|
||||||
|
self.streaming_state["sent_tools"] = sent_tools
|
||||||
|
|
||||||
|
if clean_args.endswith('}'):
|
||||||
|
self._advance_to_next_tool()
|
||||||
|
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(index=current_idx,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=clean_args_delta).model_dump(
|
||||||
|
exclude_none=True))
|
||||||
|
])
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_end_tool_calls(self, current_text: str) -> bool:
|
||||||
|
if self.tool_call_end_token not in current_text:
|
||||||
|
return False
|
||||||
|
|
||||||
|
end_token_positions = []
|
||||||
|
search_start = 0
|
||||||
|
while True:
|
||||||
|
pos = current_text.find(self.tool_call_end_token, search_start)
|
||||||
|
if pos == -1:
|
||||||
|
break
|
||||||
|
end_token_positions.append(pos)
|
||||||
|
search_start = pos + 1
|
||||||
|
|
||||||
|
think_regions = []
|
||||||
|
for match in re.finditer(self.thinking_tag_pattern,
|
||||||
|
current_text,
|
||||||
|
flags=re.DOTALL):
|
||||||
|
think_regions.append((match.start(), match.end()))
|
||||||
|
|
||||||
|
for pos in end_token_positions:
|
||||||
|
in_think = any(pos >= t_start and pos < t_end
|
||||||
|
for t_start, t_end in think_regions)
|
||||||
|
if not in_think:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def extract_tool_calls_streaming(
|
def extract_tool_calls_streaming(
|
||||||
self,
|
self,
|
||||||
previous_text: str,
|
previous_text: str,
|
||||||
@ -172,13 +674,37 @@ class MinimaxToolParser(ToolParser):
|
|||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
logger.debug("delta_text: %s", delta_text)
|
self._update_thinking_state(current_text)
|
||||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
|
||||||
|
if self.in_thinking_tag:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
if self._should_buffer_content(delta_text):
|
||||||
|
buffered_output = self._process_buffer(delta_text)
|
||||||
|
return DeltaMessage(
|
||||||
|
content=buffered_output) if buffered_output else None
|
||||||
|
|
||||||
|
if self._is_end_tool_calls(current_text):
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
safe_content, potential_tag = self._split_content_for_buffering(
|
||||||
|
delta_text)
|
||||||
|
if potential_tag:
|
||||||
|
self.pending_buffer += potential_tag
|
||||||
|
return DeltaMessage(content=safe_content) if safe_content else None
|
||||||
|
|
||||||
# Preprocess to remove tool calls from thinking tags
|
|
||||||
processed_current_text = self.preprocess_model_output(current_text)
|
processed_current_text = self.preprocess_model_output(current_text)
|
||||||
|
|
||||||
if self.tool_call_start_token not in processed_current_text:
|
if self.tool_call_start_token not in processed_current_text:
|
||||||
|
if (self.tool_call_end_token in delta_text
|
||||||
|
and self.tool_call_start_token in current_text):
|
||||||
|
return None
|
||||||
|
if delta_text.strip(
|
||||||
|
) == '' and self.tool_call_start_token in current_text:
|
||||||
|
return None
|
||||||
|
if (self._get_current_tool_index() != -1
|
||||||
|
and self.tool_call_end_token in current_text):
|
||||||
|
self._reset_streaming_state()
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
if (self.tool_call_start_token_id is not None
|
if (self.tool_call_start_token_id is not None
|
||||||
@ -186,184 +712,104 @@ class MinimaxToolParser(ToolParser):
|
|||||||
and len(delta_token_ids) == 1):
|
and len(delta_token_ids) == 1):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
original_tool_call_start_pos = current_text.find(
|
original_tool_start = self._find_tool_start_outside_thinking(
|
||||||
self.tool_call_start_token)
|
current_text)
|
||||||
if original_tool_call_start_pos > 0:
|
if original_tool_start is None:
|
||||||
delta_start_pos = len(current_text) - len(delta_text)
|
return None
|
||||||
if delta_start_pos < original_tool_call_start_pos:
|
|
||||||
content_part = delta_text
|
|
||||||
if delta_start_pos + len(
|
|
||||||
delta_text) > original_tool_call_start_pos:
|
|
||||||
content_part = delta_text[:original_tool_call_start_pos -
|
|
||||||
delta_start_pos]
|
|
||||||
if content_part:
|
|
||||||
return DeltaMessage(content=content_part)
|
|
||||||
|
|
||||||
flags = Allow.ALL if self.current_tool_name_sent \
|
content_before_tools = self._extract_content_before_tools(
|
||||||
else Allow.ALL & ~Allow.STR
|
current_text, delta_text, original_tool_start)
|
||||||
|
if content_before_tools:
|
||||||
|
return DeltaMessage(content=content_before_tools)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsable_content = processed_current_text.split(
|
tool_content = self._extract_tool_content(current_text,
|
||||||
self.tool_call_start_token)[-1].split(
|
original_tool_start)
|
||||||
self.tool_call_end_token)[0]
|
current_tools_count = self._detect_tools_in_text(tool_content)
|
||||||
|
|
||||||
tool_call_arr = []
|
if current_tools_count == 0:
|
||||||
if parsable_content.strip():
|
|
||||||
lines = parsable_content.strip().split('\n')
|
|
||||||
for line in lines:
|
|
||||||
line = line.strip()
|
|
||||||
if line and (line.startswith('{') or '"name"' in line):
|
|
||||||
try:
|
|
||||||
if line.endswith('}'):
|
|
||||||
parsed_call = json.loads(line)
|
|
||||||
tool_call_arr.append(parsed_call)
|
|
||||||
else:
|
|
||||||
parsed_call = partial_json_parser.loads(
|
|
||||||
line, flags)
|
|
||||||
if parsed_call and isinstance(
|
|
||||||
parsed_call, dict):
|
|
||||||
tool_call_arr.append(parsed_call)
|
|
||||||
except (json.JSONDecodeError, partial_json_parser.core.
|
|
||||||
exceptions.MalformedJSON):
|
|
||||||
continue
|
|
||||||
|
|
||||||
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
|
||||||
if len(tool_call_arr) > self.current_tool_id >= 0 else {}
|
|
||||||
|
|
||||||
if len(tool_call_arr) == 0:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Starting a new tool in the array
|
if self._get_current_tool_index() == -1:
|
||||||
elif (len(tool_call_arr) > 0
|
self._reset_streaming_state()
|
||||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
|
||||||
|
|
||||||
# Handle any missed arguments from previous tool
|
self._ensure_state_arrays(current_tools_count)
|
||||||
if self.current_tool_id >= 0 and self.current_tool_id < len(
|
|
||||||
self.prev_tool_call_arr):
|
|
||||||
prev_tool_call = self.prev_tool_call_arr[
|
|
||||||
self.current_tool_id]
|
|
||||||
diff_arguments = prev_tool_call.get("arguments")
|
|
||||||
|
|
||||||
if diff_arguments:
|
return (self._handle_tool_name_streaming(tool_content,
|
||||||
diff_arguments_json = json.dumps(diff_arguments,
|
current_tools_count)
|
||||||
ensure_ascii=False)
|
or self._handle_tool_args_streaming(
|
||||||
already_streamed = self.streamed_args_for_tool[
|
tool_content, current_tools_count))
|
||||||
self.
|
|
||||||
current_tool_id] if self.current_tool_id < len(
|
|
||||||
self.streamed_args_for_tool) else ""
|
|
||||||
|
|
||||||
if diff_arguments_json != already_streamed:
|
|
||||||
diff = diff_arguments_json[len(already_streamed):]
|
|
||||||
delta = DeltaMessage(tool_calls=[
|
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments=diff).model_dump(
|
|
||||||
exclude_none=True))
|
|
||||||
])
|
|
||||||
if self.current_tool_id < len(
|
|
||||||
self.streamed_args_for_tool):
|
|
||||||
self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id] = diff_arguments_json
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
self.current_tool_id = len(tool_call_arr) - 1
|
|
||||||
self.current_tool_name_sent = False
|
|
||||||
self.streamed_args_for_tool.append("")
|
|
||||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
|
||||||
return delta
|
|
||||||
|
|
||||||
# Send tool name if not sent yet
|
|
||||||
if not self.current_tool_name_sent:
|
|
||||||
function_name = current_tool_call.get("name")
|
|
||||||
if function_name:
|
|
||||||
delta = DeltaMessage(tool_calls=[
|
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
|
||||||
type="function",
|
|
||||||
id=random_tool_call_id(),
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
name=function_name).model_dump(
|
|
||||||
exclude_none=True))
|
|
||||||
])
|
|
||||||
self.current_tool_name_sent = True
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
# Stream arguments
|
|
||||||
else:
|
|
||||||
prev_arguments = None
|
|
||||||
if (self.current_tool_id < len(self.prev_tool_call_arr)
|
|
||||||
and self.prev_tool_call_arr[self.current_tool_id]):
|
|
||||||
prev_arguments = self.prev_tool_call_arr[
|
|
||||||
self.current_tool_id].get("arguments")
|
|
||||||
|
|
||||||
cur_arguments = current_tool_call.get("arguments")
|
|
||||||
|
|
||||||
if not cur_arguments and not prev_arguments:
|
|
||||||
delta = None
|
|
||||||
elif not cur_arguments and prev_arguments:
|
|
||||||
logger.error(
|
|
||||||
"Arguments reset mid-call, skipping streaming")
|
|
||||||
delta = None
|
|
||||||
elif cur_arguments and not prev_arguments:
|
|
||||||
cur_arguments_json = json.dumps(cur_arguments,
|
|
||||||
ensure_ascii=False)
|
|
||||||
logger.debug("First tokens in arguments received: %s",
|
|
||||||
cur_arguments_json)
|
|
||||||
|
|
||||||
delta = DeltaMessage(tool_calls=[
|
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments=cur_arguments_json).
|
|
||||||
model_dump(exclude_none=True))
|
|
||||||
])
|
|
||||||
self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id] = cur_arguments_json
|
|
||||||
|
|
||||||
elif cur_arguments and prev_arguments:
|
|
||||||
cur_args_json = json.dumps(cur_arguments,
|
|
||||||
ensure_ascii=False)
|
|
||||||
prev_args_json = json.dumps(prev_arguments,
|
|
||||||
ensure_ascii=False)
|
|
||||||
|
|
||||||
logger.debug("Searching for diff between \n%s\n%s",
|
|
||||||
cur_args_json, prev_args_json)
|
|
||||||
|
|
||||||
already_streamed = self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id] if self.current_tool_id < len(
|
|
||||||
self.streamed_args_for_tool) else ""
|
|
||||||
|
|
||||||
if cur_args_json.startswith(already_streamed):
|
|
||||||
argument_diff = cur_args_json[len(already_streamed):]
|
|
||||||
elif cur_args_json != already_streamed:
|
|
||||||
argument_diff = cur_args_json
|
|
||||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
|
||||||
else:
|
|
||||||
argument_diff = ""
|
|
||||||
|
|
||||||
if argument_diff:
|
|
||||||
logger.debug("got arguments diff: %s", argument_diff)
|
|
||||||
delta = DeltaMessage(tool_calls=[
|
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments=argument_diff).
|
|
||||||
model_dump(exclude_none=True))
|
|
||||||
])
|
|
||||||
self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id] += argument_diff
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
else:
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
self.prev_tool_call_arr = tool_call_arr
|
|
||||||
return delta
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("An unexpected error occurred",
|
logger.exception("An unexpected error occurred ",
|
||||||
"during streaming tool call handling.")
|
"during streaming tool call handling.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _find_tool_start_outside_thinking(self,
|
||||||
|
current_text: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Find the start position of tool calls outside of thinking tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_text: Current text to search
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Position of tool call start or None if not found
|
||||||
|
"""
|
||||||
|
search_start = 0
|
||||||
|
while True:
|
||||||
|
pos = current_text.find(self.tool_call_start_token, search_start)
|
||||||
|
if pos == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
think_regions = [(m.start(), m.end()) for m in re.finditer(
|
||||||
|
r"<think>(.*?)</think>", current_text, flags=re.DOTALL)]
|
||||||
|
in_think = any(pos >= t_start and pos < t_end
|
||||||
|
for t_start, t_end in think_regions)
|
||||||
|
|
||||||
|
if not in_think:
|
||||||
|
return pos
|
||||||
|
|
||||||
|
search_start = pos + 1
|
||||||
|
|
||||||
|
def _extract_content_before_tools(self, current_text: str, delta_text: str,
|
||||||
|
tool_start: int) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract content that appears before tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_text: Current text
|
||||||
|
delta_text: Delta text
|
||||||
|
tool_start: Start position of tools
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content before tools or None
|
||||||
|
"""
|
||||||
|
if tool_start > 0:
|
||||||
|
delta_start_pos = len(current_text) - len(delta_text)
|
||||||
|
if delta_start_pos < tool_start:
|
||||||
|
content_part = delta_text
|
||||||
|
if delta_start_pos + len(delta_text) > tool_start:
|
||||||
|
content_part = delta_text[:tool_start - delta_start_pos]
|
||||||
|
return content_part if content_part else None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
|
||||||
|
"""
|
||||||
|
Extract tool content from current text starting at tool_start.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_text: Current text
|
||||||
|
tool_start: Start position of tool calls
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Extracted tool content
|
||||||
|
"""
|
||||||
|
tool_content_start = tool_start + len(self.tool_call_start_token)
|
||||||
|
tool_content = current_text[tool_content_start:]
|
||||||
|
|
||||||
|
end_pos = tool_content.find(self.tool_call_end_token)
|
||||||
|
if end_pos != -1:
|
||||||
|
tool_content = tool_content[:end_pos]
|
||||||
|
|
||||||
|
return tool_content
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user