From 363528de27db30c2cde6c9abdf3da1544e8c1fe4 Mon Sep 17 00:00:00 2001
From: qscqesze
Date: Thu, 3 Jul 2025 14:48:27 +0800
Subject: [PATCH] [Feature] Support MiniMax-M1 function calls features (#20297)

Signed-off-by: QscQ
Signed-off-by: qingjun
---
 docs/features/tool_calling.md                |   9 +
 examples/tool_chat_template_minimax_m1.jinja |  91 +++++
 tests/tool_use/test_minimax_tool_parser.py   | 371 ++++++++++++++++++
 .../openai/tool_parsers/__init__.py          |   3 +-
 .../tool_parsers/minimax_tool_parser.py      | 369 +++++++++++++++++
 5 files changed, 842 insertions(+), 1 deletion(-)
 create mode 100644 examples/tool_chat_template_minimax_m1.jinja
 create mode 100644 tests/tool_use/test_minimax_tool_parser.py
 create mode 100644 vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py

diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index 29f4b2300bbe..8858b9a4015a 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -264,6 +264,15 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included support for the Hermes-style tool use.
 
 Flags: `--tool-call-parser hermes`
 
+### MiniMax Models (`minimax`)
+
+Supported models:
+
+* `MiniMaxAi/MiniMax-M1-40k` (use with <gh-file:examples/tool_chat_template_minimax_m1.jinja>)
+* `MiniMaxAi/MiniMax-M1-80k` (use with <gh-file:examples/tool_chat_template_minimax_m1.jinja>)
+
+Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja`
+
 ### DeepSeek-V3 Models (`deepseek_v3`)
 
 Supported models:
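The docs section above only lists the launch flags. For completeness, here is a minimal client-side sketch (not part of the patch) that exercises the new parser against a server started with those flags; the tool schema and the city are illustrative assumptions:

```python
# Sketch only. Assumes a server started with:
#   vllm serve MiniMaxAi/MiniMax-M1-40k \
#     --enable-auto-tool-choice --tool-call-parser minimax \
#     --chat-template examples/tool_chat_template_minimax_m1.jinja
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="MiniMaxAi/MiniMax-M1-40k",
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical example tool
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
)
# If the parser recognized a <tool_calls> block, tool_calls is populated.
print(response.choices[0].message.tool_calls)
```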
diff --git a/examples/tool_chat_template_minimax_m1.jinja b/examples/tool_chat_template_minimax_m1.jinja
new file mode 100644
index 000000000000..2d5bbf4de56f
--- /dev/null
+++ b/examples/tool_chat_template_minimax_m1.jinja
@@ -0,0 +1,91 @@
+{{ '<begin_of_document>' -}}
+{%- if custom_tools is defined %}
+    {%- set tools = custom_tools %}
+{%- endif %}
+{%- if not tools is defined %}
+    {%- set tools = none %}
+{%- endif %}
+
+{#- Extract system message #}
+{% set ns = namespace(system_prompt='') -%}
+{%- if messages[0]['role'] == 'system' %}
+    {%- if messages[0]['content'] is string %}
+        {%- set ns.system_prompt = messages[0]['content']|trim %}
+    {%- else %}
+        {%- set ns.system_prompt = messages[0]['content'][0]['text']|trim %}
+    {%- endif %}
+    {%- set messages = messages[1:] %}
+{%- else %}
+    {%- if tools is not none %}
+        {%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
+    {%- else %}
+        {%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
+    {%- endif %}
+{%- endif %}
+
+{#- System message #}
+{%- if ns.system_prompt != '' %}
+{{ '<beginning_of_sentence>system ai_setting=assistant\n' + ns.system_prompt + '<end_of_sentence>\n' -}}
+{%- endif %}
+
+{#- Tools configuration #}
+{%- if tools is not none %}
+{{ '<beginning_of_sentence>system tool_setting=tools\nYou are provided with these tools:\n<tools>\n' -}}
+{%- for tool in tools %}
+{{ tool | tojson ~ '\n' -}}
+{%- endfor %}
+{{ '</tools>\n\nIf you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:\n<tool_calls>\n{"name": <tool-name>, "arguments": <args-json-object>}\n...\n</tool_calls><end_of_sentence>\n' -}}
+{%- endif %}
+
+{#- Process messages #}
+{%- for message in messages %}
+    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
+        {%- if message['role'] == 'user' %}
+{{ '<beginning_of_sentence>user name=user\n' -}}
+{%- if message['content'] is string %}
+{{ message['content']|trim -}}
+{%- else %}
+{%- for content in message['content'] %}
+{%- if content['type'] == 'text' %}
+{{ content['text']|trim -}}
+{%- endif %}
+{%- endfor %}
+{%- endif %}
+{{ '<end_of_sentence>\n' -}}
+        {%- elif message['role'] == 'assistant' %}
+{{ '<beginning_of_sentence>ai name=assistant\n' -}}
+{%- if message['content'] is string %}
+{{ message['content']|trim -}}
+{%- else %}
+{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}
+{{ content['text']|trim -}}
+{%- endfor %}
+{%- endif %}
+{{ '<end_of_sentence>\n' -}}
+        {%- endif %}
+    {%- elif 'tool_calls' in message %}
+{{ '<beginning_of_sentence>ai name=assistant\n<tool_calls>\n' -}}
+{%- for tool_call in message.tool_calls %}
+{{ '{"name": "' + tool_call.function.name + '", "arguments": ' + tool_call.function.arguments | tojson + '}\n' -}}
+{%- endfor %}
+{{ '</tool_calls><end_of_sentence>\n' -}}
+    {%- elif message.role == "tool" or message.role == "ipython" %}
+{{ '<beginning_of_sentence>tool name=tools\n' -}}
+{%- if message.content is string %}
+{{ 'tool result: ' + message.content + '\n\n' -}}
+{%- else %}
+{%- for content in message['content'] %}
+{%- if content['type'] == 'text' %}
+{{ 'tool result: ' + content['text'] + '\n\n' -}}
+{%- elif content.get('name') %}
+{{ 'tool name: ' + content['name'] + '\ntool result: ' + content['text'] + '\n\n' -}}
+{%- endif %}
+{%- endfor %}
+{%- endif %}
+{{ '<end_of_sentence>\n' -}}
+    {%- endif %}
+{%- endfor %}
+
+{%- if add_generation_prompt %}
+{{ '<beginning_of_sentence>ai name=assistant\n' -}}
+{%- endif %}
\ No newline at end of file
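Because the template encodes the prompt wire format, it can be sanity-checked offline. A small sketch (assuming the file path above and a locally downloadable tokenizer) renders it through `transformers`; the message and tool schema are made up for illustration:

```python
# Sketch only: render the new template offline to inspect the prompt format.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MiniMaxAi/MiniMax-M1-40k",
                                          trust_remote_code=True)
with open("examples/tool_chat_template_minimax_m1.jinja") as f:
    chat_template = f.read()

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical example tool
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}]
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=tools,
    chat_template=chat_template,
    add_generation_prompt=True,
    tokenize=False,
)
# Expect a <tools> block in the system section and a trailing
# "<beginning_of_sentence>ai name=assistant\n" generation prompt.
print(prompt)
```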
diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py
new file mode 100644
index 000000000000..0c9a574e03dc
--- /dev/null
+++ b/tests/tool_use/test_minimax_tool_parser.py
@@ -0,0 +1,371 @@
+# SPDX-License-Identifier: Apache-2.0
+# ruff: noqa: E501
+
+import json
+
+import pytest
+
+from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+# Use a common model that is likely to be available
+MODEL = "MiniMaxAi/MiniMax-M1-40k"
+
+
+@pytest.fixture(scope="module")
+def minimax_tokenizer():
+    return get_tokenizer(tokenizer_name=MODEL)
+
+
+@pytest.fixture
+def minimax_tool_parser(minimax_tokenizer):
+    return MinimaxToolParser(minimax_tokenizer)
+
+
+def assert_tool_calls(actual_tool_calls: list[ToolCall],
+                      expected_tool_calls: list[ToolCall]):
+    assert len(actual_tool_calls) == len(expected_tool_calls)
+
+    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+                                                    expected_tool_calls):
+        assert isinstance(actual_tool_call.id, str)
+        assert len(actual_tool_call.id) > 16
+
+        assert actual_tool_call.type == "function"
+        assert actual_tool_call.function == expected_tool_call.function
+
+
+def test_extract_tool_calls_no_tools(minimax_tool_parser):
+    model_output = "This is a test"
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert not extracted_tool_calls.tools_called
+    assert extracted_tool_calls.tool_calls == []
+    assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_call",
+        "multiple_tool_calls",
+        "tool_call_with_content_before",
+        "tool_call_with_single_line_json",
+        "tool_call_incomplete_tag",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """<tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
+</tool_calls>""",
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            """<tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
+{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
+</tool_calls>""",
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Orlando",
+                        "state": "FL",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+            ],
+            None,
+        ),
+        (
+            """I'll help you check the weather. <tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
+</tool_calls>""",
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Seattle",
+                        "state": "WA",
+                        "unit": "celsius",
+                    }),
+                ))
+            ],
+            "I'll help you check the weather.",
+        ),
+        (
+            """<tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}}
+</tool_calls>""",
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "New York",
+                        "state": "NY",
+                        "unit": "celsius",
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            """<tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""",
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Boston",
+                        "state": "MA",
+                    }),
+                ))
+            ],
+            None,
+        ),
+    ],
+)
+def test_extract_tool_calls(minimax_tool_parser, model_output,
+                            expected_tool_calls, expected_content):
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert extracted_tool_calls.tools_called
+
+    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+    assert extracted_tool_calls.content == expected_content
+
+
+def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser):
+    """Test that tool calls within thinking tags are removed during preprocessing."""
+    model_output = """<think>Let me think about this. <tool_calls>
+{"name": "fake_tool", "arguments": {"param": "value"}}
+</tool_calls> This should be removed.
+</think>
+
+I'll help you with that. <tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}}
+</tool_calls>"""
+
+    processed_output = minimax_tool_parser.preprocess_model_output(
+        model_output)
+
+    # The tool call within thinking tags should be removed
+    assert "fake_tool" not in processed_output
+    # But the thinking tag itself should remain
+    assert "<think>" in processed_output
+    assert "</think>" in processed_output
+    # The actual tool call outside thinking tags should remain
+    assert "get_current_weather" in processed_output
+
+
+def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser):
+    """Test tool extraction when thinking tags contain tool calls that should be ignored."""
+    model_output = """<think>I should use a tool. <tool_calls>
+{"name": "ignored_tool", "arguments": {"should": "ignore"}}
+</tool_calls>
+</think>
+Let me help you with the weather. <tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
+</tool_calls>"""
+
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+
+    assert extracted_tool_calls.tools_called
+    assert len(extracted_tool_calls.tool_calls) == 1
+    assert extracted_tool_calls.tool_calls[
+        0].function.name == "get_current_weather"
+
+    # Content extraction is based on the position of the first <tool_calls> in the original model_output
+    # Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one
+    expected_content = """<think>I should use a tool. <tool_calls>
+{"name": "ignored_tool", "arguments": {"should": "ignore"}}
+</tool_calls>
+</think>
+Let me help you with the weather."""
+    assert extracted_tool_calls.content == expected_content
+
+
+def test_extract_tool_calls_invalid_json(minimax_tool_parser):
+    """Test that invalid JSON in tool calls is handled gracefully."""
+    model_output = """<tool_calls>
+{"name": "valid_tool", "arguments": {"city": "Seattle"}}
+{invalid json here}
+{"name": "another_valid_tool", "arguments": {"param": "value"}}
+</tool_calls>"""
+
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+
+    assert extracted_tool_calls.tools_called
+    # Should extract only the valid JSON tool calls
+    assert len(extracted_tool_calls.tool_calls) == 2
+    assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
+    assert extracted_tool_calls.tool_calls[
+        1].function.name == "another_valid_tool"
+
+
+def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser):
+    """Test that tool calls missing name or arguments are filtered out."""
+    model_output = """<tool_calls>
+{"name": "valid_tool", "arguments": {"city": "Seattle"}}
+{"name": "missing_args"}
+{"arguments": {"city": "Portland"}}
+{"name": "another_valid_tool", "arguments": {"param": "value"}}
+</tool_calls>"""
+
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+
+    assert extracted_tool_calls.tools_called
+    # Should extract only the valid tool calls with both name and arguments
+    assert len(extracted_tool_calls.tool_calls) == 2
+    assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
+    assert extracted_tool_calls.tool_calls[
+        1].function.name == "another_valid_tool"
+
+
+def test_streaming_basic_functionality(minimax_tool_parser):
+    """Test basic streaming functionality."""
+    # Reset streaming state
+    minimax_tool_parser.current_tool_name_sent = False
+    minimax_tool_parser.prev_tool_call_arr = []
+    minimax_tool_parser.current_tool_id = -1
+    minimax_tool_parser.streamed_args_for_tool = []
+
+    # Test with a simple tool call
+    current_text = """<tool_calls>
+{"name": "get_current_weather", "arguments": {"city": "Seattle"}}
+</tool_calls>"""
+
+    # First call should handle the initial setup
+    result = minimax_tool_parser.extract_tool_calls_streaming(
+        previous_text="",
+        current_text=current_text,
+        delta_text="",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    # The result might be None or contain tool call information
+    # This depends on the internal state management
+    if result is not None and hasattr(result,
+                                      'tool_calls') and result.tool_calls:
+        assert len(result.tool_calls) >= 0
+
+
+def test_streaming_with_content_before_tool_calls(minimax_tool_parser):
+    """Test streaming when there's content before tool calls."""
+    # Reset streaming state
+    minimax_tool_parser.current_tool_name_sent = False
+    minimax_tool_parser.prev_tool_call_arr = []
+    minimax_tool_parser.current_tool_id = -1
+    minimax_tool_parser.streamed_args_for_tool = []
+
+    current_text = "I'll help you with that. <tool_calls>"
+
+    # When there's content before tool calls, it should be returned as content
+    result = minimax_tool_parser.extract_tool_calls_streaming(
+        previous_text="I'll help you",
+        current_text=current_text,
+        delta_text=" with that. <tool_calls>",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    if result is not None and hasattr(result, 'content'):
+        # Should contain some content
+        assert result.content is not None
+
+
+def test_streaming_no_tool_calls(minimax_tool_parser):
+    """Test streaming when there are no tool calls."""
+    current_text = "This is just regular text without any tool calls."
+
+    result = minimax_tool_parser.extract_tool_calls_streaming(
+        previous_text="This is just regular text",
+        current_text=current_text,
+        delta_text=" without any tool calls.",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    # Should return the delta text as content
+    assert result is not None
+    assert hasattr(result, 'content')
+    assert result.content == " without any tool calls."
+
+
+def test_streaming_with_thinking_tags(minimax_tool_parser):
+    """Test streaming with thinking tags that contain tool calls."""
+    # Reset streaming state
+    minimax_tool_parser.current_tool_name_sent = False
+    minimax_tool_parser.prev_tool_call_arr = []
+    minimax_tool_parser.current_tool_id = -1
+    minimax_tool_parser.streamed_args_for_tool = []
+
+    current_text = """<think><tool_calls>{"name": "ignored", "arguments": {}}</tool_calls></think><tool_calls>{"name": "real_tool", "arguments": {"param": "value"}}</tool_calls>"""
+
+    result = minimax_tool_parser.extract_tool_calls_streaming(
+        previous_text="",
+        current_text=current_text,
+        delta_text=current_text,
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    # The preprocessing should remove tool calls from thinking tags
+    # and only process the real tool call
+    if result is not None and hasattr(result,
+                                      'tool_calls') and result.tool_calls:
+        for tool_call in result.tool_calls:
+            assert tool_call.function.name != "ignored"
+
+
+def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
+    """Test that multiline JSON in tool calls is not currently supported."""
+    model_output = """<tool_calls>
+{
+    "name": "get_current_weather",
+    "arguments": {
+        "city": "New York",
+        "state": "NY",
+        "unit": "celsius"
+    }
+}
+</tool_calls>"""
+
+    extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+
+    # Multiline JSON is currently not supported, should return no tools called
+    assert not extracted_tool_calls.tools_called
+    assert extracted_tool_calls.tool_calls == []
+    assert extracted_tool_calls.content is None
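The streaming tests above call the parser with single snapshots; in serving, deltas arrive incrementally. A rough sketch (not part of the patch) of driving the same API chunk by chunk, with token id lists left empty as in the tests and a made-up completion string:

```python
# Sketch only: feed a completion to the streaming parser chunk by chunk.
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

parser = MinimaxToolParser(get_tokenizer("MiniMaxAi/MiniMax-M1-40k"))

chunks = [
    "I can help. ",
    '<tool_calls>\n{"name": "get_current_weather", ',
    '"arguments": {"city": "Dallas"}}\n</tool_calls>',
]
previous = ""
for chunk in chunks:
    current = previous + chunk
    delta = parser.extract_tool_calls_streaming(
        previous_text=previous,
        current_text=current,
        delta_text=chunk,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None)  # type: ignore[arg-type]
    if delta is not None:
        # DeltaMessage carrying content and/or tool_call fragments
        print(delta)
    previous = current
```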
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 46bd665e767d..57e675515e12 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -10,6 +10,7 @@ from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
 from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
 from .llama_tool_parser import Llama3JsonToolParser
+from .minimax_tool_parser import MinimaxToolParser
 from .mistral_tool_parser import MistralToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
@@ -20,5 +21,5 @@ __all__ = [
     "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
     "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
-    "DeepSeekV3ToolParser", "xLAMToolParser"
+    "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
 ]
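The name registered below via `@ToolParserManager.register_module("minimax")` is what `--tool-call-parser minimax` resolves at startup. A quick sketch of that lookup, assuming the `ToolParserManager` API from `abstract_tool_parser.py` behaves as it does for the other parsers:

```python
# Sketch only: resolve the parser class by its registered name, the same
# name that is passed to --tool-call-parser.
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParserManager)

parser_cls = ToolParserManager.get_tool_parser("minimax")
assert parser_cls.__name__ == "MinimaxToolParser"
```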
diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
new file mode 100644
index 000000000000..6ba32e38fcde
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
@@ -0,0 +1,369 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Sequence
+from typing import Union
+
+import partial_json_parser
+import regex as re
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("minimax")
+class MinimaxToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = []
+
+        self.tool_call_start_token: str = "<tool_calls>"
+        self.tool_call_end_token: str = "</tool_calls>"
+
+        self.tool_call_regex = re.compile(
+            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
+
+        # Add regex pattern for thinking tag
+        self.thinking_tag_pattern = r"<think>(.*?)</think>"
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction.")
+
+        self.tool_call_start_token_id = self.vocab.get(
+            self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+        if (self.tool_call_start_token_id is None
+                or self.tool_call_end_token_id is None):
+            logger.warning(
+                "Minimax Tool parser could not locate tool call start/end "
+                "tokens in the tokenizer. Falling back to string matching.")
+
+    def preprocess_model_output(self, model_output: str) -> str:
+        """
+        Remove tool calls from within thinking tags to avoid processing them.
+        """
+
+        def remove_tool_calls_from_think(match):
+            think_content = match.group(1)
+            # Remove tool_calls from within the think tag
+            cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
+                                     "",
+                                     think_content,
+                                     flags=re.DOTALL)
+            return f"<think>{cleaned_content}</think>"
+
+        # Process thinking tags and remove tool_calls from within them
+        processed_output = re.sub(self.thinking_tag_pattern,
+                                  remove_tool_calls_from_think,
+                                  model_output,
+                                  flags=re.DOTALL)
+
+        return processed_output
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+
+        # Preprocess to remove tool calls from thinking tags
+        processed_output = self.preprocess_model_output(model_output)
+
+        if self.tool_call_start_token not in processed_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        try:
+            function_call_tuples = (
+                self.tool_call_regex.findall(processed_output))
+
+            raw_function_calls = []
+            for match in function_call_tuples:
+                tool_call_content = match[0] if match[0] else match[1]
+                if tool_call_content.strip():
+                    lines = tool_call_content.strip().split('\n')
+                    for line in lines:
+                        line = line.strip()
+                        if line and line.startswith('{') and line.endswith(
+                                '}'):
+                            try:
+                                parsed_call = json.loads(line)
+                                raw_function_calls.append(parsed_call)
+                            except json.JSONDecodeError:
+                                continue
+
+            tool_calls = []
+            for function_call in raw_function_calls:
+                if "name" in function_call and "arguments" in function_call:
+                    tool_calls.append(
+                        ToolCall(type="function",
+                                 function=FunctionCall(
+                                     name=function_call["name"],
+                                     arguments=json.dumps(
+                                         function_call["arguments"],
+                                         ensure_ascii=False))))
+
+            # Extract content before the first valid tool call
+            # Find the position in processed output, then map back to original
+            processed_pos = processed_output.find(self.tool_call_start_token)
+            if processed_pos != -1:
+                # Get the content before tool calls in processed output
+                processed_content = processed_output[:processed_pos].strip()
+
+                if processed_content:
+                    # Find the end of this content in the original output
+                    # Look for the last non-empty line of processed content
+                    lines = processed_content.split('\n')
+                    for line in reversed(lines):
+                        line = line.strip()
+                        if line:
+                            # Find this line in original output
+                            pos = model_output.find(line)
+                            if pos != -1:
+                                content = model_output[:pos + len(line)]
+                                break
+                    else:
+                        content = ""
+                else:
+                    content = ""
+            else:
+                content = model_output
+
+            return ExtractedToolCallInformation(
+                tools_called=len(tool_calls) > 0,
+                tool_calls=tool_calls,
+                content=content.strip() if content.strip() else None)
+
+        except Exception:
+            logger.exception(
+                "An unexpected error occurred during tool call extraction.")
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        logger.debug("delta_text: %s", delta_text)
+        logger.debug("delta_token_ids: %s", delta_token_ids)
+
+        # Preprocess to remove tool calls from thinking tags
+        processed_current_text = self.preprocess_model_output(current_text)
+
+        if self.tool_call_start_token not in processed_current_text:
+            return DeltaMessage(content=delta_text)
+
+        if (self.tool_call_start_token_id is not None
+                and self.tool_call_start_token_id in delta_token_ids
+                and len(delta_token_ids) == 1):
+            return None
+
+        original_tool_call_start_pos = current_text.find(
+            self.tool_call_start_token)
+        if original_tool_call_start_pos > 0:
+            delta_start_pos = len(current_text) - len(delta_text)
+            if delta_start_pos < original_tool_call_start_pos:
+                content_part = delta_text
+                if delta_start_pos + len(
+                        delta_text) > original_tool_call_start_pos:
+                    content_part = delta_text[:original_tool_call_start_pos -
+                                              delta_start_pos]
+                if content_part:
+                    return DeltaMessage(content=content_part)
+
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+
+        try:
+            parsable_content = processed_current_text.split(
+                self.tool_call_start_token)[-1].split(
+                    self.tool_call_end_token)[0]
+
+            tool_call_arr = []
+            if parsable_content.strip():
+                lines = parsable_content.strip().split('\n')
+                for line in lines:
+                    line = line.strip()
+                    if line and (line.startswith('{') or '"name"' in line):
+                        try:
+                            if line.endswith('}'):
+                                parsed_call = json.loads(line)
+                                tool_call_arr.append(parsed_call)
+                            else:
+                                parsed_call = partial_json_parser.loads(
+                                    line, flags)
+                                if parsed_call and isinstance(
+                                        parsed_call, dict):
+                                    tool_call_arr.append(parsed_call)
+                        except (json.JSONDecodeError, partial_json_parser.
+                                core.exceptions.MalformedJSON):
+                            continue
+
+            current_tool_call: dict = tool_call_arr[self.current_tool_id] \
+                if len(tool_call_arr) > self.current_tool_id >= 0 else {}
+
+            if len(tool_call_arr) == 0:
+                return None
+
+            # Starting a new tool in the array
+            elif (len(tool_call_arr) > 0
+                  and len(tool_call_arr) > self.current_tool_id + 1):
+
+                # Handle any missed arguments from previous tool
+                if self.current_tool_id >= 0 and self.current_tool_id < len(
+                        self.prev_tool_call_arr):
+                    prev_tool_call = self.prev_tool_call_arr[
+                        self.current_tool_id]
+                    diff_arguments = prev_tool_call.get("arguments")
+
+                    if diff_arguments:
+                        diff_arguments_json = json.dumps(diff_arguments,
+                                                         ensure_ascii=False)
+                        already_streamed = self.streamed_args_for_tool[
+                            self.current_tool_id] if self.current_tool_id < len(
+                                self.streamed_args_for_tool) else ""
+
+                        if diff_arguments_json != already_streamed:
+                            diff = diff_arguments_json[len(already_streamed):]
+                            delta = DeltaMessage(tool_calls=[
+                                DeltaToolCall(index=self.current_tool_id,
+                                              function=DeltaFunctionCall(
+                                                  arguments=diff).model_dump(
+                                                      exclude_none=True))
+                            ])
+                            if self.current_tool_id < len(
+                                    self.streamed_args_for_tool):
+                                self.streamed_args_for_tool[
+                                    self.current_tool_id] = diff_arguments_json
+                        else:
+                            delta = None
+                    else:
+                        delta = None
+                else:
+                    delta = None
+
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug("starting on new tool %d", self.current_tool_id)
+                return delta
+
+            # Send tool name if not sent yet
+            if not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      type="function",
+                                      id=random_tool_call_id(),
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                else:
+                    delta = None
+
+            # Stream arguments
+            else:
+                prev_arguments = None
+                if (self.current_tool_id < len(self.prev_tool_call_arr)
+                        and self.prev_tool_call_arr[self.current_tool_id]):
+                    prev_arguments = self.prev_tool_call_arr[
+                        self.current_tool_id].get("arguments")
+
+                cur_arguments = current_tool_call.get("arguments")
+
+                if not cur_arguments and not prev_arguments:
+                    delta = None
+                elif not cur_arguments and prev_arguments:
+                    logger.error(
+                        "Arguments reset mid-call, skipping streaming")
+                    delta = None
+                elif cur_arguments and not prev_arguments:
+                    cur_arguments_json = json.dumps(cur_arguments,
+                                                    ensure_ascii=False)
+                    logger.debug("First tokens in arguments received: %s",
+                                 cur_arguments_json)
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=cur_arguments_json).
+                                      model_dump(exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] = cur_arguments_json
+
+                elif cur_arguments and prev_arguments:
+                    cur_args_json = json.dumps(cur_arguments,
+                                               ensure_ascii=False)
+                    prev_args_json = json.dumps(prev_arguments,
+                                                ensure_ascii=False)
+
+                    logger.debug("Searching for diff between \n%s\n%s",
+                                 cur_args_json, prev_args_json)
+
+                    already_streamed = self.streamed_args_for_tool[
+                        self.current_tool_id] if self.current_tool_id < len(
+                            self.streamed_args_for_tool) else ""
+
+                    if cur_args_json.startswith(already_streamed):
+                        argument_diff = cur_args_json[len(already_streamed):]
+                    elif cur_args_json != already_streamed:
+                        argument_diff = cur_args_json
+                        self.streamed_args_for_tool[self.current_tool_id] = ""
+                    else:
+                        argument_diff = ""
+
+                    if argument_diff:
+                        logger.debug("got arguments diff: %s", argument_diff)
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=argument_diff).
+                                          model_dump(exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += argument_diff
+                    else:
+                        delta = None
+                else:
+                    delta = None
+
+            self.prev_tool_call_arr = tool_call_arr
+            return delta
+
+        except Exception:
+            logger.exception("An unexpected error occurred "
+                             "during streaming tool call handling.")
+            return None
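To exercise the non-streaming path outside the test suite, a canned completion in the format the chat template requests can be fed straight to the parser. A sketch; the output string is illustrative:

```python
# Sketch only: drive the parser directly on a canned completion.
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

parser = MinimaxToolParser(get_tokenizer("MiniMaxAi/MiniMax-M1-40k"))
model_output = (
    'Checking the weather now. <tool_calls>\n'
    '{"name": "get_current_weather", "arguments": {"city": "Dallas"}}\n'
    '</tool_calls>')
info = parser.extract_tool_calls(model_output,
                                 request=None)  # type: ignore[arg-type]
print(info.tools_called)                 # True
print(info.tool_calls[0].function.name)  # get_current_weather
print(info.content)                      # Checking the weather now.
```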