From 5e9455ae8f33599865f8855b28db2d074ea04eb5 Mon Sep 17 00:00:00 2001 From: qscqesze Date: Thu, 7 Aug 2025 11:30:27 +0800 Subject: [PATCH] [Bugfix]: Fix the streaming output for function calls in the minimax (#22015) Signed-off-by: QscQ Signed-off-by: qingjun --- tests/tool_use/test_minimax_tool_parser.py | 846 ++++++++++++++++- .../tool_parsers/minimax_tool_parser.py | 850 +++++++++++++----- 2 files changed, 1493 insertions(+), 203 deletions(-) diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py index 49b8e4b96f1bb..ddf26007121e5 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_use/test_minimax_tool_parser.py @@ -3,10 +3,12 @@ # ruff: noqa: E501 import json +from typing import Any import pytest -from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser from vllm.transformers_utils.tokenizer import get_tokenizer @@ -24,6 +26,57 @@ def minimax_tool_parser(minimax_tokenizer): return MinimaxToolParser(minimax_tokenizer) +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam(type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + }, + "state": { + "type": "string", + "description": + "The state code" + }, + "unit": { + "type": "string", + "enum": + ["fahrenheit", "celsius"] + } + }, + "required": ["city", "state"] + } + }), + ChatCompletionToolsParam(type="function", + function={ + "name": "calculate_area", + "description": + "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": { + "type": "string" + }, + "dimensions": { + "type": "object" + }, + "precision": { + "type": "integer" + } + } + } + }) + ] + + def assert_tool_calls(actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]): assert len(actual_tool_calls) == len(expected_tool_calls) @@ -370,3 +423,794 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser): assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content is None + + +def test_streaming_arguments_incremental_output(minimax_tool_parser): + """Test that streaming arguments are returned incrementally, not cumulatively.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + + # Simulate progressive tool call building + stages = [ + # Stage 1: Function name complete + '\n{"name": "get_current_weather", "arguments": ', + # Stage 2: Arguments object starts with first key + '\n{"name": "get_current_weather", "arguments": {"city": ', + # Stage 3: First parameter value added + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle"', + # Stage 4: Second parameter added + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"', + # Stage 5: Third parameter added, arguments complete + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + # Stage 6: Tool calls closed + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", 
"unit": "celsius"}}\n\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n' + ] + + function_name_sent = False + previous_args_content = "" + + for i, current_text in enumerate(stages): + previous_text = stages[i - 1] if i > 0 else "" + delta_text = current_text[len(previous_text + ):] if i > 0 else current_text + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Stage {i}: Current text: {repr(current_text)}") + print(f"Stage {i}: Delta text: {repr(delta_text)}") + + if result is not None and hasattr(result, + 'tool_calls') and result.tool_calls: + tool_call = result.tool_calls[0] + + # Check if function name is sent (should happen only once) + if tool_call.function and tool_call.function.name: + assert tool_call.function.name == "get_current_weather" + function_name_sent = True + print( + f"Stage {i}: Function name sent: {tool_call.function.name}" + ) + + # Check if arguments are sent incrementally + if tool_call.function and tool_call.function.arguments: + args_fragment = tool_call.function.arguments + print( + f"Stage {i}: Got arguments fragment: {repr(args_fragment)}" + ) + + # For incremental output, each fragment should be new content only + # The fragment should not contain all previous content + if i >= 2 and previous_args_content: # After we start getting arguments + # The new fragment should not be identical to or contain all previous content + assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}" + + # If this is truly incremental, the fragment should be relatively small + # compared to the complete arguments so far + if len(args_fragment) > len(previous_args_content): + print( + "Warning: Fragment seems cumulative rather than incremental" + ) + + previous_args_content = args_fragment + + # Verify function name was sent at least once + assert function_name_sent, "Function name should have been sent" + + +def test_streaming_arguments_delta_only(minimax_tool_parser): + """Test that each streaming call returns only the delta (new part) of arguments.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + + # Simulate two consecutive calls with growing arguments + call1_text = '\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + call2_text = '\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}' + + print(f"Call 1 text: {repr(call1_text)}") + print(f"Call 2 text: {repr(call2_text)}") + + # First call - should get the function name and initial arguments + result1 = minimax_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=call1_text, + delta_text=call1_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result 1: {result1}") + if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls: + for i, tc in enumerate(result1.tool_calls): + print(f" Tool call {i}: {tc}") + + # Second call - should only get the delta (new part) of arguments + result2 = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=call1_text, + current_text=call2_text, + delta_text=', "param2": "value2"}', + 
previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result 2: {result2}") + if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls: + for i, tc in enumerate(result2.tool_calls): + print(f" Tool call {i}: {tc}") + + # Verify the second call only returns the delta + if result2 is not None and hasattr(result2, + 'tool_calls') and result2.tool_calls: + tool_call = result2.tool_calls[0] + if tool_call.function and tool_call.function.arguments: + args_delta = tool_call.function.arguments + print(f"Arguments delta from second call: {repr(args_delta)}") + + # Should only contain the new part, not the full arguments + # The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"' + assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}" + + # Should NOT contain the previous parameter data + assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}" + + # The delta should be relatively short (incremental, not cumulative) + expected_max_length = len( + ', "param2": "value2"}') + 10 # Some tolerance + assert len( + args_delta + ) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}" + + print("✓ Delta validation passed") + else: + print("No arguments in result2 tool call") + else: + print("No tool calls in result2 or result2 is None") + # This might be acceptable if no incremental update is needed + # But let's at least verify that result1 had some content + assert result1 is not None, "At least the first call should return something" + + +def test_streaming_openai_compatibility(minimax_tool_parser): + """Test that streaming behavior with buffering works correctly.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + # Reset buffering state + minimax_tool_parser.pending_buffer = "" + minimax_tool_parser.in_thinking_tag = False + minimax_tool_parser.thinking_depth = 0 + + # Test scenario: simple buffering without complex tool call context + test_cases: list[dict[str, Any]] = [ + { + 'stage': 'Token: <', + 'previous': '', + 'current': '<', + 'delta': '<', + 'expected_content': None, # Should be buffered + }, + { + 'stage': 'Token: tool_calls>', + 'previous': '<', + 'current': '', + 'delta': 'tool_calls>', + 'expected_content': None, # Complete tag, should not output + }, + { + 'stage': 'Regular content', + 'previous': 'Hello', + 'current': 'Hello world', + 'delta': ' world', + 'expected_content': ' world', # Normal content should pass through + }, + { + 'stage': 'Content with end tag start', + 'previous': 'Text', + 'current': 'Text content', + 'delta': 'calls>', + 'expected_content': None, # Complete close tag, should not output + }, + ] + + for i, test_case in enumerate(test_cases): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print(f"Previous: {repr(test_case['previous'])}") + print(f"Current: {repr(test_case['current'])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check 
expected content + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content, got {result}" + print("✓ No content output as expected") + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + print(f"✓ Content matches: {repr(result.content)}") + + print("✓ Streaming test with buffering completed successfully") + + +def test_streaming_thinking_tag_buffering(minimax_tool_parser): + """Test that tool calls within thinking tags are properly handled during streaming.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + # Reset buffering state + minimax_tool_parser.pending_buffer = "" + minimax_tool_parser.in_thinking_tag = False + minimax_tool_parser.thinking_depth = 0 + + # Test scenario: tool calls within thinking tags should be ignored + test_cases: list[dict[str, Any]] = [ + { + 'stage': 'Start thinking', + 'previous': '', + 'current': 'I need to use a tool. ', + 'delta': 'I need to use a tool. ', + 'expected_content': + 'I need to use a tool. ', # Should pass through as content + }, + { + 'stage': + 'Tool call in thinking', + 'previous': + 'I need to use a tool. ', + 'current': + 'I need to use a tool. \n{"name": "ignored_tool", "arguments": {"param": "value"}}\n', + 'delta': + '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n', + 'expected_content': + '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n', # should be preserved in thinking tags + }, + { + 'stage': 'Real tool call after thinking', + 'previous': + 'I need to use a tool. \n{"name": "ignored_tool", "arguments": {"param": "value"}}\n', + 'current': + 'I need to use a tool. 
\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n\n', + 'delta': '\n', + 'expected_content': + '\n', # Should output '\n' and suppress + } + ] + + for i, test_case in enumerate(test_cases): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print(f"Previous: {repr(test_case['previous'])}") + print(f"Current: {repr(test_case['current'])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check expected content + if 'expected_content' in test_case: + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content, got {result}" + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + print(f"✓ Content matches: {repr(result.content)}") + + # Check tool calls + if test_case.get('expected_tool_call'): + assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \ + f"Stage {i}: Expected tool call, got {result}" + + tool_call = result.tool_calls[0] + assert tool_call.function.name == "real_tool", \ + f"Expected real_tool, got {tool_call.function.name}" + print(f"✓ Real tool call detected: {tool_call.function.name}") + + print("✓ Thinking tag buffering test completed successfully") + + +def reset_streaming_state(minimax_tool_parser): + """Helper function to properly reset the streaming state for MinimaxToolParser.""" + # Reset minimax-specific state + minimax_tool_parser._reset_streaming_state() + + # Reset base class state (these should still be reset for compatibility) + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.streamed_args_for_tool = [] + + +def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): + """Test complex streaming scenario: tools inside tags and multiple tool calls in one group.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Complex scenario: tools inside thinking tags and multiple tools in one group + test_stages: list[dict[str, Any]] = [ + { + 'stage': 'Initial content', + 'previous': '', + 'current': 'Let me help you with this task.', + 'delta': 'Let me help you with this task.', + 'expected_content': 'Let me help you with this task.', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Start thinking tag', + 'previous': 'Let me help you with this task.', + 'current': + 'Let me help you with this task.I need to analyze this situation first.', + 'delta': 'I need to analyze this situation first.', + 'expected_content': + 'I need to analyze this situation first.', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Tool call inside thinking tag starts', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.', + 'current': + 'Let me help you with this task.I need to analyze this situation first.', + 'delta': '', + 'expected_content': + '', # Inside thinking tags, tool tags should be preserved as content + 'expected_tool_calls': 0, + }, + { + 'stage': 'Complete tool call 
inside thinking tag', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'delta': + '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'expected_content': + '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'expected_tool_calls': + 0, # Tools inside thinking tags should be ignored + }, + { + 'stage': 'End thinking tag', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'delta': '', + 'expected_content': '', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Multiple tools group starts', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.', + 'delta': + '\nNow I need to get weather information and calculate area.', + 'expected_content': + '\nNow I need to get weather information and calculate area.', # should be filtered + 'expected_tool_calls': 0, + }, + { + 'stage': 'First tool in group', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'delta': + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'expected_content': + None, # No content should be output when tool call is in progress + 'expected_tool_calls': 1, + 'expected_tool_name': 'get_current_weather', + }, + { + 'stage': 'Second tool in group', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'delta': + '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'expected_content': None, + 'expected_tool_calls': 1, + 'expected_tool_name': 
'calculate_area', + }, + { + 'stage': 'Complete tool calls group', + 'previous': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'current': + 'Let me help you with this task.I need to analyze this situation first.\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n\nNow I need to get weather information and calculate area.\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'delta': '', + 'expected_content': None, + 'expected_tool_calls': 0, + } + ] + + tool_calls_count = 0 + + for i, test_case in enumerate(test_stages): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print( + f"Previous: {repr(test_case['previous'][:100])}{'...' if len(test_case['previous']) > 100 else ''}" + ) + print(f"Current: {repr(test_case['current'][-100:])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check expected content + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content output, got {result}" + print("✓ No content output as expected") + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content output, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}" + print(f"✓ Content matches: {repr(result.content)}") + + # Check tool calls + expected_tool_calls = test_case['expected_tool_calls'] + actual_tool_calls = len(result.tool_calls) if result and hasattr( + result, 'tool_calls') and result.tool_calls else 0 + + if expected_tool_calls > 0: + assert actual_tool_calls >= expected_tool_calls, \ + f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}" + + if 'expected_tool_name' in test_case: + # Find the tool call with the expected name + found_tool_call = None + for tool_call in result.tool_calls: + if tool_call.function.name == test_case[ + 'expected_tool_name']: + found_tool_call = tool_call + break + + assert found_tool_call is not None, \ + f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}" + print(f"✓ Tool call correct: {found_tool_call.function.name}") + + # Ensure tools inside thinking tags are not called + assert found_tool_call.function.name != "internal_analysis", \ + f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called" + + tool_calls_count += actual_tool_calls + print(f"✓ Detected {actual_tool_calls} tool calls") + else: + assert actual_tool_calls == 0, \ + f"Stage {i}: Expected no tool calls, got {actual_tool_calls}" + + # Verify overall results + 
print("\n=== Test Summary ===") + print(f"Total tool calls count: {tool_calls_count}") + assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + + print("✓ Complex streaming test completed:") + print(" - ✓ Tools inside thinking tags correctly ignored") + print(" - ✓ Two tool groups outside thinking tags correctly parsed") + print(" - ✓ Content and tool call streaming correctly handled") + print(" - ✓ Buffering mechanism works correctly") + + +def test_streaming_character_by_character_output(minimax_tool_parser): + """Test character-by-character streaming output to simulate real streaming scenarios.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Complete text that will be streamed character by character + complete_text = """I'll help you with the weather analysis. Let me think about this. +{"name": "internal_analysis", "arguments": {"type": "thinking"}} +This tool should be ignored. + +Now I'll get the weather information for you. +{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}} +{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}} +Here are the results.""" + + print("\n=== Starting character-by-character streaming test ===") + print(f"Complete text length: {len(complete_text)} characters") + + # Track the streaming results + content_fragments = [] + tool_calls_detected = [] + + # Stream character by character + for i in range(1, len(complete_text) + 1): + current_text = complete_text[:i] + previous_text = complete_text[:i - 1] if i > 1 else "" + delta_text = complete_text[i - 1:i] + + # Show progress every 50 characters + if i % 50 == 0 or i == len(complete_text): + print(f"Progress: {i}/{len(complete_text)} characters") + + # Call the streaming parser + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Collect results + if result is not None: + if hasattr(result, 'content') and result.content: + content_fragments.append(result.content) + # Log important content fragments + if any( + keyword in result.content for keyword in + ['', '', '', '']): + print( + f" Char {i}: Content fragment: {repr(result.content)}" + ) + + if hasattr(result, 'tool_calls') and result.tool_calls: + for tool_call in result.tool_calls: + tool_info = { + 'character_position': + i, + 'function_name': + tool_call.function.name + if tool_call.function else None, + 'arguments': + tool_call.function.arguments + if tool_call.function else None, + } + tool_calls_detected.append(tool_info) + print( + f" Char {i}: Tool call detected: {tool_call.function.name}" + ) + if tool_call.function.arguments: + print( + f" Arguments: {repr(tool_call.function.arguments)}" + ) + + # Verify results + print("\n=== Streaming Test Results ===") + print(f"Total content fragments: {len(content_fragments)}") + print(f"Total tool calls detected: {len(tool_calls_detected)}") + + # Reconstruct content from fragments + reconstructed_content = ''.join(content_fragments) + print(f"Reconstructed content length: {len(reconstructed_content)}") + + # Verify thinking tags content is preserved + assert '' in reconstructed_content, "Opening thinking tag should be preserved in content" + assert '' in reconstructed_content, "Closing thinking tag should be preserved in 
content" + + # Verify that tool calls inside thinking tags are NOT extracted as actual tool calls + thinking_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'internal_analysis' + ] + assert len( + thinking_tool_calls + ) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + + # Verify that real tool calls outside thinking tags ARE extracted + weather_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'get_current_weather' + ] + area_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'calculate_area' + ] + print(tool_calls_detected) + assert len(weather_tool_calls + ) > 0, "get_current_weather tool call should be detected" + assert len( + area_tool_calls) > 0, "calculate_area tool call should be detected" + + # Verify tool call arguments are properly streamed + weather_args_found = any(tc['arguments'] for tc in weather_tool_calls + if tc['arguments']) + area_args_found = any(tc['arguments'] for tc in area_tool_calls + if tc['arguments']) + + print(f"Weather tool call with arguments: {weather_args_found}") + print(f"Area tool call with arguments: {area_args_found}") + + # Verify content before and after tool calls + assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved" + assert 'Here are the results.' in reconstructed_content, "Final content should be preserved" + + # Verify that and tags are not included in the final content + # (they should be filtered out when not inside thinking tags) + content_outside_thinking = reconstructed_content + # Remove thinking tag content to check content outside + if '' in content_outside_thinking and '' in content_outside_thinking: + start_think = content_outside_thinking.find('') + end_think = content_outside_thinking.find('') + len('') + content_outside_thinking = content_outside_thinking[: + start_think] + content_outside_thinking[ + end_think:] + + # Outside thinking tags, tool_calls tags should be filtered + tool_calls_in_content = content_outside_thinking.count('') + assert tool_calls_in_content == 0, f" tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" + + print( + "\n=== Character-by-character streaming test completed successfully ===" + ) + print("✓ Tool calls inside thinking tags correctly ignored") + print("✓ Tool calls outside thinking tags correctly detected") + print("✓ Content properly streamed and reconstructed") + print("✓ Tool call tags properly filtered from content") + print("✓ Character-level streaming works correctly") + + +def test_streaming_character_by_character_simple_tool_call( + minimax_tool_parser): + """Test character-by-character streaming for a simple tool call scenario.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Simple tool call text + simple_text = 'Let me check the weather. 
\n{"name": "get_weather", "arguments": {"city": "NYC"}}\n' + + print("\n=== Simple character-by-character test ===") + print(f"Text: {repr(simple_text)}") + + content_parts = [] + tool_name_sent = False + tool_args_sent = False + + for i in range(1, len(simple_text) + 1): + current_text = simple_text[:i] + previous_text = simple_text[:i - 1] if i > 1 else "" + delta_text = simple_text[i - 1:i] + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + if result: + if hasattr(result, 'content') and result.content: + content_parts.append(result.content) + print( + f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}" + ) + + if hasattr(result, 'tool_calls') and result.tool_calls: + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_name_sent = True + print( + f" Char {i}: Tool name: {tool_call.function.name}" + ) + if tool_call.function and tool_call.function.arguments: + tool_args_sent = True + print( + f" Char {i}: Tool args: {repr(tool_call.function.arguments)}" + ) + + # Verify basic expectations + reconstructed_content = ''.join(content_parts) + print(f"Final reconstructed content: {repr(reconstructed_content)}") + + assert tool_name_sent, "Tool name should be sent during streaming" + assert tool_args_sent, "Tool arguments should be sent during streaming" + assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved" + + print("✓ Simple character-by-character test passed") + + +def test_streaming_character_by_character_with_buffering(minimax_tool_parser): + """Test character-by-character streaming with edge cases that trigger buffering.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Text that includes potential buffering scenarios + buffering_text = 'Hello world\n{"name": "test"}\ndone' + + print("\n=== Buffering character-by-character test ===") + print(f"Text: {repr(buffering_text)}") + + all_content = [] + + for i in range(1, len(buffering_text) + 1): + current_text = buffering_text[:i] + previous_text = buffering_text[:i - 1] if i > 1 else "" + delta_text = buffering_text[i - 1:i] + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + if result and hasattr(result, 'content') and result.content: + all_content.append(result.content) + print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}") + + final_content = ''.join(all_content) + print(f"Final content: {repr(final_content)}") + + # The parser should handle the edge case where appears before + assert "Hello" in final_content, "Initial 'Hello' should be preserved" + assert "world" in final_content, "Content after false closing tag should be preserved" + assert "done" in final_content, "Final content should be preserved" + + print("✓ Buffering character-by-character test passed") diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 6ba32e38fcde2..226309ef293a9 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -3,11 +3,9 @@ import json from collections.abc 
import Sequence
-from typing import Union
+from typing import Any, Optional, Union
 
-import partial_json_parser
 import regex as re
-from partial_json_parser.core.options import Allow
 
 from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -17,6 +15,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (
+    extract_intermediate_diff)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
@@ -29,25 +29,32 @@ class MinimaxToolParser(ToolParser):
 
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        self.current_tool_name_sent: bool = False
-        self.prev_tool_call_arr: list[dict] = []
-        self.current_tool_id: int = -1
-        self.streamed_args_for_tool: list[str] = []
-
-        self.tool_call_start_token: str = "<tool_calls>"
-        self.tool_call_end_token: str = "</tool_calls>"
+        # Initialize streaming state for tracking tool call progress
+        self.streaming_state: dict[str, Any] = {
+            "current_tool_index": -1,  # Index of current tool being processed
+            "tool_ids": [],  # List of tool call IDs
+            "sent_tools": [],  # List of tools that have been sent
+        }
 
+        # Define tool call tokens and patterns
+        self.tool_call_start_token = "<tool_calls>"
+        self.tool_call_end_token = "</tool_calls>"
         self.tool_call_regex = re.compile(
             r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
-
-        # Add regex pattern for thinking tag
        self.thinking_tag_pattern = r"<think>(.*?)</think>"
+        self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
+        self.tool_args_pattern = re.compile(r'"arguments":\s*')
+
+        # Buffer for handling partial tool calls during streaming
+        self.pending_buffer = ""
+        self.in_thinking_tag = False
 
         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ToolParser "
                 "constructor during construction.")
 
+        # Get token IDs for tool call start/end tokens
         self.tool_call_start_token_id = self.vocab.get(
             self.tool_call_start_token)
         self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
@@ -60,33 +67,95 @@ class MinimaxToolParser(ToolParser):
 
     def preprocess_model_output(self, model_output: str) -> str:
         """
-        Remove tool calls from within thinking tags to avoid processing them.
+        Preprocess model output by removing tool calls from thinking tags.
+
+        Args:
+            model_output: Raw model output string
+
+        Returns:
+            Preprocessed model output with tool calls removed from thinking tags
         """
 
         def remove_tool_calls_from_think(match):
             think_content = match.group(1)
-            # Remove tool_calls from within the think tag
             cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
                                      "",
                                      think_content,
                                      flags=re.DOTALL)
             return f"<think>{cleaned_content}</think>"
 
-        # Process thinking tags and remove tool_calls from within them
-        processed_output = re.sub(self.thinking_tag_pattern,
-                                  remove_tool_calls_from_think,
-                                  model_output,
-                                  flags=re.DOTALL)
+        return re.sub(self.thinking_tag_pattern,
+                      remove_tool_calls_from_think,
+                      model_output,
+                      flags=re.DOTALL)
 
-        return processed_output
+    def _clean_duplicate_braces(self, args_text: str) -> str:
+        """
+        Clean duplicate closing braces from arguments text.
+ + Args: + args_text: Raw arguments text + + Returns: + Cleaned arguments text with proper JSON formatting + """ + args_text = args_text.strip() + if not args_text: + return args_text + + try: + json.loads(args_text) + return args_text + except json.JSONDecodeError: + pass + + while args_text.endswith('}}'): + candidate = args_text[:-1] + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + args_text = candidate + + return args_text + + def _clean_delta_braces(self, delta_text: str) -> str: + """ + Clean delta text by removing excessive closing braces. + + Args: + delta_text: Delta text to clean + + Returns: + Cleaned delta text + """ + if not delta_text: + return delta_text + + delta_stripped = delta_text.strip() + + if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped): + brace_count = delta_stripped.count('}') + if brace_count > 1: + return '}\n' if delta_text.endswith('\n') else '}' + + return delta_text def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - - # Preprocess to remove tool calls from thinking tags + """ + Extract tool calls from model output for non-streaming mode. + + Args: + model_output: Complete model output + request: Chat completion request + + Returns: + ExtractedToolCallInformation containing tool calls and content + """ processed_output = self.preprocess_model_output(model_output) if self.tool_call_start_token not in processed_output: @@ -95,8 +164,8 @@ class MinimaxToolParser(ToolParser): content=model_output) try: - function_call_tuples = ( - self.tool_call_regex.findall(processed_output)) + function_call_tuples = self.tool_call_regex.findall( + processed_output) raw_function_calls = [] for match in function_call_tuples: @@ -124,21 +193,15 @@ class MinimaxToolParser(ToolParser): function_call["arguments"], ensure_ascii=False)))) - # Extract content before the first valid tool call - # Find the position in processed output, then map back to original processed_pos = processed_output.find(self.tool_call_start_token) if processed_pos != -1: - # Get the content before tool calls in processed output processed_content = processed_output[:processed_pos].strip() if processed_content: - # Find the end of this content in the original output - # Look for the last non-empty line of processed content lines = processed_content.split('\n') for line in reversed(lines): line = line.strip() if line: - # Find this line in original output pos = model_output.find(line) if pos != -1: content = model_output[:pos + len(line)] @@ -162,6 +225,445 @@ class MinimaxToolParser(ToolParser): tool_calls=[], content=model_output) + def _update_thinking_state(self, text: str) -> None: + """ + Update the thinking tag state based on text content. + + Args: + text: Text to analyze for thinking tags + """ + open_count = text.count("") + close_count = text.count("") + self.in_thinking_tag = open_count > close_count or ( + open_count == close_count and text.endswith("")) + + def _is_potential_tag_start(self, text: str) -> bool: + """ + Check if text might be the start of a tool call tag. 
+ + Args: + text: Text to check + + Returns: + True if text could be the start of a tool call tag + """ + for tag in [self.tool_call_start_token, self.tool_call_end_token]: + if any( + tag.startswith(text[-i:]) + for i in range(1, min(len(text) + 1, len(tag)))): + return True + return False + + def _should_buffer_content(self, delta_text: str) -> bool: + """ + Determine if content should be buffered for later processing. + + Args: + delta_text: Delta text to check + + Returns: + True if content should be buffered + """ + if self.in_thinking_tag: + return False + return bool(self.pending_buffer + or self.tool_call_start_token in delta_text + or self.tool_call_end_token in delta_text + or delta_text.startswith('<')) + + def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]: + """ + Split delta text into safe content and potential tag content. + + Args: + delta_text: Delta text to split + + Returns: + Tuple of (safe_content, potential_tag_content) + """ + if self.in_thinking_tag: + return delta_text, "" + + for tag in [self.tool_call_start_token, self.tool_call_end_token]: + for i in range(1, len(tag)): + tag_prefix = tag[:i] + pos = delta_text.rfind(tag_prefix) + if pos != -1 and tag.startswith(delta_text[pos:]): + return delta_text[:pos], delta_text[pos:] + return delta_text, "" + + def _process_buffer(self, new_content: str) -> str: + """ + Process buffered content and return output content. + + Args: + new_content: New content to add to buffer + + Returns: + Processed output content + """ + self.pending_buffer += new_content + output_content = "" + + if self.in_thinking_tag: + output_content = self.pending_buffer + self.pending_buffer = "" + return output_content + + while self.pending_buffer: + start_pos = self.pending_buffer.find(self.tool_call_start_token) + end_pos = self.pending_buffer.find(self.tool_call_end_token) + + if start_pos != -1 and (end_pos == -1 or start_pos < end_pos): + tag_pos, tag_len = start_pos, len(self.tool_call_start_token) + elif end_pos != -1: + tag_pos, tag_len = end_pos, len(self.tool_call_end_token) + else: + if self._is_potential_tag_start(self.pending_buffer): + break + output_content += self.pending_buffer + self.pending_buffer = "" + break + + output_content += self.pending_buffer[:tag_pos] + self.pending_buffer = self.pending_buffer[tag_pos + tag_len:] + + return output_content + + def _reset_streaming_state(self) -> None: + """Reset the streaming state to initial values.""" + self.streaming_state = { + "current_tool_index": -1, + "tool_ids": [], + "sent_tools": [], + } + + def _advance_to_next_tool(self) -> None: + """Advance to the next tool in the streaming sequence.""" + self.streaming_state["current_tool_index"] = int( + self.streaming_state["current_tool_index"]) + 1 + + def _set_current_tool_index(self, index: int) -> None: + """ + Set the current tool index. + + Args: + index: Tool index to set + """ + self.streaming_state["current_tool_index"] = index + + def _get_current_tool_index(self) -> int: + """ + Get the current tool index. + + Returns: + Current tool index + """ + return int(self.streaming_state["current_tool_index"]) + + def _get_next_unsent_tool_index(self, tool_count: int) -> int: + """ + Get the index of the next unsent tool. 
+ + Args: + tool_count: Total number of tools + + Returns: + Index of next unsent tool, or -1 if all tools sent + """ + sent_tools = list(self.streaming_state["sent_tools"]) + for i in range(tool_count): + if i < len(sent_tools): + if not sent_tools[i]["sent_name"]: + return i + else: + return i + return -1 + + def _ensure_state_arrays(self, tool_count: int) -> None: + """ + Ensure state arrays have sufficient capacity for tool_count tools. + + Args: + tool_count: Number of tools to prepare for + """ + sent_tools = list(self.streaming_state["sent_tools"]) + tool_ids = list(self.streaming_state["tool_ids"]) + + while len(sent_tools) < tool_count: + sent_tools.append({ + "sent_name": False, + "sent_arguments": "", + "id": random_tool_call_id(), + }) + + while len(tool_ids) < tool_count: + tool_ids.append(None) + + self.streaming_state["sent_tools"] = sent_tools + self.streaming_state["tool_ids"] = tool_ids + + def _detect_tools_in_text(self, text: str) -> int: + """ + Detect the number of tools in text by counting name patterns. + + Args: + text: Text to analyze + + Returns: + Number of tools detected + """ + matches = self.tool_name_pattern.findall(text) + return len(matches) + + def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: + """ + Find the boundaries of tool calls in text. + + Args: + text: Text to analyze + + Returns: + List of (start, end) positions for tool calls + """ + boundaries = [] + i = 0 + while i < len(text): + if text[i] == '{': + start = i + depth = 0 + has_name = False + has_arguments = False + + while i < len(text): + if text[i] == '{': + depth += 1 + elif text[i] == '}': + depth -= 1 + if depth == 0: + end = i + 1 + segment = text[start:end] + if '"name"' in segment and '"arguments"' in segment: + boundaries.append((start, end)) + break + + if not has_name and '"name"' in text[start:i + 1]: + has_name = True + if not has_arguments and '"arguments"' in text[start:i + + 1]: + has_arguments = True + + i += 1 + + if depth > 0 and has_name: + boundaries.append((start, i)) + else: + i += 1 + return boundaries + + def _extract_tool_args(self, tool_content: str, args_match) -> str: + """ + Extract tool arguments from tool content. + + Args: + tool_content: Tool call content + args_match: Regex match for arguments pattern + + Returns: + Extracted arguments as string + """ + args_start_pos = args_match.end() + remaining_content = tool_content[args_start_pos:] + + if remaining_content.strip().startswith('{'): + depth = 0 + for i, char in enumerate(remaining_content): + if char == '{': + depth += 1 + elif char == '}': + depth -= 1 + if depth == 0: + return remaining_content[:i + 1] + else: + args_end = remaining_content.find('}') + if args_end > 0: + return remaining_content[:args_end].strip() + + return remaining_content.rstrip('}').strip() + + def _get_current_tool_content( + self, text: str, + tool_index: int) -> tuple[Optional[str], Optional[str]]: + """ + Get the content of a specific tool by index. 
+ + Args: + text: Text containing tool calls + tool_index: Index of tool to extract + + Returns: + Tuple of (tool_name, tool_arguments) or (None, None) if not found + """ + boundaries = self._find_tool_boundaries(text) + + if tool_index >= len(boundaries): + return None, None + + start, end = boundaries[tool_index] + tool_content = text[start:end] + + name_match = self.tool_name_pattern.search(tool_content) + name = name_match.group(1) if name_match else None + + args_match = self.tool_args_pattern.search(tool_content) + if args_match: + try: + args_text = self._extract_tool_args(tool_content, args_match) + return name, args_text + except Exception: + remaining_content = tool_content[args_match.end():] + args_text = remaining_content.rstrip('}').strip() + return name, args_text + + return name, None + + def _handle_tool_name_streaming( + self, tool_content: str, + tool_count: int) -> Union[DeltaMessage, None]: + """ + Handle streaming of tool names. + + Args: + tool_content: Content containing tool calls + tool_count: Total number of tools + + Returns: + DeltaMessage with tool name or None if no tool to stream + """ + next_idx = self._get_next_unsent_tool_index(tool_count) + + if next_idx == -1: + return None + + boundaries = self._find_tool_boundaries(tool_content) + if next_idx >= len(boundaries): + return None + + tool_name, _ = self._get_current_tool_content(tool_content, next_idx) + if not tool_name: + return None + + self._set_current_tool_index(next_idx) + sent_tools = list(self.streaming_state["sent_tools"]) + tool_ids = list(self.streaming_state["tool_ids"]) + + tool_id = sent_tools[next_idx]["id"] + tool_ids[next_idx] = tool_id + sent_tools[next_idx]["sent_name"] = True + + self.streaming_state["sent_tools"] = sent_tools + self.streaming_state["tool_ids"] = tool_ids + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=next_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=tool_name).model_dump(exclude_none=True)) + ]) + + def _handle_tool_args_streaming( + self, tool_content: str, + tool_count: int) -> Union[DeltaMessage, None]: + """ + Handle streaming of tool arguments. 
+ + Args: + tool_content: Content containing tool calls + tool_count: Total number of tools + + Returns: + DeltaMessage with tool arguments or None if no arguments to stream + """ + current_idx = self._get_current_tool_index() + + if current_idx < 0 or current_idx >= tool_count: + return None + + tool_name, tool_args = self._get_current_tool_content( + tool_content, current_idx) + if not tool_name or tool_args is None: + return None + + sent_tools = list(self.streaming_state["sent_tools"]) + + if not sent_tools[current_idx]["sent_name"]: + return None + + clean_args = self._clean_duplicate_braces(tool_args) + sent_args = sent_tools[current_idx]["sent_arguments"] + + if clean_args != sent_args: + if sent_args and clean_args.startswith(sent_args): + args_delta = extract_intermediate_diff(clean_args, sent_args) + if args_delta: + args_delta = self._clean_delta_braces(args_delta) + sent_tools[current_idx]["sent_arguments"] = clean_args + self.streaming_state["sent_tools"] = sent_tools + + if clean_args.endswith('}'): + self._advance_to_next_tool() + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=current_idx, + function=DeltaFunctionCall( + arguments=args_delta).model_dump( + exclude_none=True)) + ]) + elif not sent_args and clean_args: + clean_args_delta = self._clean_delta_braces(clean_args) + sent_tools[current_idx]["sent_arguments"] = clean_args + self.streaming_state["sent_tools"] = sent_tools + + if clean_args.endswith('}'): + self._advance_to_next_tool() + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=current_idx, + function=DeltaFunctionCall( + arguments=clean_args_delta).model_dump( + exclude_none=True)) + ]) + + return None + + def _is_end_tool_calls(self, current_text: str) -> bool: + if self.tool_call_end_token not in current_text: + return False + + end_token_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.tool_call_end_token, search_start) + if pos == -1: + break + end_token_positions.append(pos) + search_start = pos + 1 + + think_regions = [] + for match in re.finditer(self.thinking_tag_pattern, + current_text, + flags=re.DOTALL): + think_regions.append((match.start(), match.end())) + + for pos in end_token_positions: + in_think = any(pos >= t_start and pos < t_end + for t_start, t_end in think_regions) + if not in_think: + return True + + return False + def extract_tool_calls_streaming( self, previous_text: str, @@ -172,13 +674,37 @@ class MinimaxToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) - logger.debug("delta_token_ids: %s", delta_token_ids) + self._update_thinking_state(current_text) + + if self.in_thinking_tag: + return DeltaMessage(content=delta_text) + + if self._should_buffer_content(delta_text): + buffered_output = self._process_buffer(delta_text) + return DeltaMessage( + content=buffered_output) if buffered_output else None + + if self._is_end_tool_calls(current_text): + return DeltaMessage(content=delta_text) + + safe_content, potential_tag = self._split_content_for_buffering( + delta_text) + if potential_tag: + self.pending_buffer += potential_tag + return DeltaMessage(content=safe_content) if safe_content else None - # Preprocess to remove tool calls from thinking tags processed_current_text = self.preprocess_model_output(current_text) if self.tool_call_start_token not in processed_current_text: + if (self.tool_call_end_token in delta_text + and self.tool_call_start_token in current_text): + 
return None + if delta_text.strip( + ) == '' and self.tool_call_start_token in current_text: + return None + if (self._get_current_tool_index() != -1 + and self.tool_call_end_token in current_text): + self._reset_streaming_state() return DeltaMessage(content=delta_text) if (self.tool_call_start_token_id is not None @@ -186,184 +712,104 @@ class MinimaxToolParser(ToolParser): and len(delta_token_ids) == 1): return None - original_tool_call_start_pos = current_text.find( - self.tool_call_start_token) - if original_tool_call_start_pos > 0: - delta_start_pos = len(current_text) - len(delta_text) - if delta_start_pos < original_tool_call_start_pos: - content_part = delta_text - if delta_start_pos + len( - delta_text) > original_tool_call_start_pos: - content_part = delta_text[:original_tool_call_start_pos - - delta_start_pos] - if content_part: - return DeltaMessage(content=content_part) + original_tool_start = self._find_tool_start_outside_thinking( + current_text) + if original_tool_start is None: + return None - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + content_before_tools = self._extract_content_before_tools( + current_text, delta_text, original_tool_start) + if content_before_tools: + return DeltaMessage(content=content_before_tools) try: - parsable_content = processed_current_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0] + tool_content = self._extract_tool_content(current_text, + original_tool_start) + current_tools_count = self._detect_tools_in_text(tool_content) - tool_call_arr = [] - if parsable_content.strip(): - lines = parsable_content.strip().split('\n') - for line in lines: - line = line.strip() - if line and (line.startswith('{') or '"name"' in line): - try: - if line.endswith('}'): - parsed_call = json.loads(line) - tool_call_arr.append(parsed_call) - else: - parsed_call = partial_json_parser.loads( - line, flags) - if parsed_call and isinstance( - parsed_call, dict): - tool_call_arr.append(parsed_call) - except (json.JSONDecodeError, partial_json_parser.core. - exceptions.MalformedJSON): - continue - - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > self.current_tool_id >= 0 else {} - - if len(tool_call_arr) == 0: + if current_tools_count == 0: return None - # Starting a new tool in the array - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): + if self._get_current_tool_index() == -1: + self._reset_streaming_state() - # Handle any missed arguments from previous tool - if self.current_tool_id >= 0 and self.current_tool_id < len( - self.prev_tool_call_arr): - prev_tool_call = self.prev_tool_call_arr[ - self.current_tool_id] - diff_arguments = prev_tool_call.get("arguments") + self._ensure_state_arrays(current_tools_count) - if diff_arguments: - diff_arguments_json = json.dumps(diff_arguments, - ensure_ascii=False) - already_streamed = self.streamed_args_for_tool[ - self. 
- current_tool_id] if self.current_tool_id < len( - self.streamed_args_for_tool) else "" - - if diff_arguments_json != already_streamed: - diff = diff_arguments_json[len(already_streamed):] - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - if self.current_tool_id < len( - self.streamed_args_for_tool): - self.streamed_args_for_tool[ - self.current_tool_id] = diff_arguments_json - else: - delta = None - else: - delta = None - else: - delta = None - - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta - - # Send tool name if not sent yet - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=random_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None - - # Stream arguments - else: - prev_arguments = None - if (self.current_tool_id < len(self.prev_tool_call_arr) - and self.prev_tool_call_arr[self.current_tool_id]): - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - - cur_arguments = current_tool_call.get("arguments") - - if not cur_arguments and not prev_arguments: - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "Arguments reset mid-call, skipping streaming") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - logger.debug("First tokens in arguments received: %s", - cur_arguments_json) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments_json). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments_json - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) - - already_streamed = self.streamed_args_for_tool[ - self.current_tool_id] if self.current_tool_id < len( - self.streamed_args_for_tool) else "" - - if cur_args_json.startswith(already_streamed): - argument_diff = cur_args_json[len(already_streamed):] - elif cur_args_json != already_streamed: - argument_diff = cur_args_json - self.streamed_args_for_tool[self.current_tool_id] = "" - else: - argument_diff = "" - - if argument_diff: - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - else: - delta = None - - self.prev_tool_call_arr = tool_call_arr - return delta + return (self._handle_tool_name_streaming(tool_content, + current_tools_count) + or self._handle_tool_args_streaming( + tool_content, current_tools_count)) except Exception: - logger.exception("An unexpected error occurred", + logger.exception("An unexpected error occurred ", "during streaming tool call handling.") return None + + def _find_tool_start_outside_thinking(self, + current_text: str) -> Optional[int]: + """ + Find the start position of tool calls outside of thinking tags. + + Args: + current_text: Current text to search + + Returns: + Position of tool call start or None if not found + """ + search_start = 0 + while True: + pos = current_text.find(self.tool_call_start_token, search_start) + if pos == -1: + return None + + think_regions = [(m.start(), m.end()) for m in re.finditer( + r"(.*?)", current_text, flags=re.DOTALL)] + in_think = any(pos >= t_start and pos < t_end + for t_start, t_end in think_regions) + + if not in_think: + return pos + + search_start = pos + 1 + + def _extract_content_before_tools(self, current_text: str, delta_text: str, + tool_start: int) -> Optional[str]: + """ + Extract content that appears before tool calls. + + Args: + current_text: Current text + delta_text: Delta text + tool_start: Start position of tools + + Returns: + Content before tools or None + """ + if tool_start > 0: + delta_start_pos = len(current_text) - len(delta_text) + if delta_start_pos < tool_start: + content_part = delta_text + if delta_start_pos + len(delta_text) > tool_start: + content_part = delta_text[:tool_start - delta_start_pos] + return content_part if content_part else None + return None + + def _extract_tool_content(self, current_text: str, tool_start: int) -> str: + """ + Extract tool content from current text starting at tool_start. + + Args: + current_text: Current text + tool_start: Start position of tool calls + + Returns: + Extracted tool content + """ + tool_content_start = tool_start + len(self.tool_call_start_token) + tool_content = current_text[tool_content_start:] + + end_pos = tool_content.find(self.tool_call_end_token) + if end_pos != -1: + tool_content = tool_content[:end_pos] + + return tool_content
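
The sketch below is editorial and not part of the patch: it shows how the reworked extract_tool_calls_streaming entry point can be driven, mirroring the character-by-character loops in the tests above. The model id is a placeholder assumption; any tokenizer accepted by MinimaxToolParser works the same way, and request=None with empty token-id lists matches how the unit tests call the parser.

# Hypothetical driver for MinimaxToolParser streaming (assumes vLLM is installed
# and that "MiniMaxAI/MiniMax-M1" is a valid tokenizer id -- adjust as needed).
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer


def stream_tool_call_deltas(full_text: str,
                            model: str = "MiniMaxAI/MiniMax-M1"):
    """Feed full_text one character at a time and yield every DeltaMessage."""
    parser = MinimaxToolParser(get_tokenizer(model))
    previous = ""
    for ch in full_text:
        current = previous + ch
        delta = parser.extract_tool_calls_streaming(
            previous_text=previous,
            current_text=current,
            delta_text=ch,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        if delta is not None:
            # Each delta carries either plain .content or incremental
            # .tool_calls (name first, then argument fragments).
            yield delta
        previous = current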