diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 9400a67267f4c..d2502079d0de9 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer( "single_tool_add", "single_tool_weather", "multiple_tool_calls", + "complex", + "wrong_json", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer( ], None, ), + ( + # Complex + """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="bash", + arguments=json.dumps( + {"command": "print(\"hello world!\")\nre.compile(r'{}')"} + )[:-2], + ) + ) + ], + "hi{hi", + ), + ( + # Wrong json + """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="bash", + arguments=json.dumps( + {"command": "print(\"hello world!\")\nre.compile(r'{}')"} + ), + ) + ) + ], + "hi{hi", + ), ], ) def test_extract_tool_calls( @@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming( ), ( # Complex - """[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 + """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 [ ToolCall( function=FunctionCall( @@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming( ) ) ], - "", + "hi{hi", ), ], ) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index 49a175f69f434..35b853b0ad7e1 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -131,78 +131,105 @@ class MistralToolParser(ToolParser): request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: """ - Extract the tool calls from a complete model response. Requires - find-and-replacing single quotes with double quotes for JSON parsing, - make sure your tool call arguments don't ever include quotes! + Extract the tool calls from a complete model response. + + Content and tool calls formatting depends on the Mistral's tokenizer version + used to train the model: + + - < v11: `content[BOT] [{tool_call1},{tool_call2}]` + - >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}` + + with [BOT] the tool call token. + + Note: + For tokenizer versions >= v11, tool calls with arguments wrongly formatted + are still returned as tool calls. This is to allow the model to know it + tried to make a tool call. It reduces chance of another failure and + prevents that the context is filled with tool calls wrongly placed in + assistant message contents. """ - # case -- if a tool call token is not present, return a text response + # If the tool call token is not present, return a text response if self.bot_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) - # first remove the BOT token - tool_content = model_output.replace(self.bot_token, "").strip() + content_and_raw_tool_calls = model_output.split(self.bot_token) + content = content_and_raw_tool_calls[0] + raw_tool_calls = content_and_raw_tool_calls[1:] - try: + # >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2} + if not self._is_pre_v11: + tool_calls = [] + for raw_tool_call in raw_tool_calls: + if "{" not in raw_tool_call: + continue + + end_name = raw_tool_call.find("{") + tool_name, args = ( + raw_tool_call[:end_name], + raw_tool_call[end_name:], + ) + + tool_calls.append({"name": tool_name, "arguments": args}) + + # < v11: content[BOT] [{tool_call1},{tool_call2}] + else: + if len(raw_tool_calls) != 1: + raise ValueError( + "Only one BOT token should have been outputted, " + f"but got {model_output}." + ) + stringified_tool_calls = raw_tool_calls[0].strip() try: - if not self._is_pre_v11: - function_call_arr = [] - for single_tool_content in model_output.split(self.bot_token): - if "{" not in single_tool_content: - continue - - end_name = single_tool_content.find("{") - fn_name, args = ( - single_tool_content[:end_name], - single_tool_content[end_name:], - ) - - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append( - {"name": fn_name, "arguments": json.loads(args)} - ) - else: - function_call_arr = json.loads(tool_content) + tool_calls = json.loads(stringified_tool_calls) except json.JSONDecodeError: # use a regex to find the part corresponding to the tool call. # NOTE: This use case should not happen if the model is trained # correctly. It's an easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls - raw_tool_call = self.tool_call_regex.findall(tool_content)[0] - function_call_arr = json.loads(raw_tool_call) - - # Tool Call - tool_calls: list[MistralToolCall] = [ - MistralToolCall( - type="function", - function=FunctionCall( - name=raw_function_call["name"], - # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call["arguments"], ensure_ascii=False + try: + raw_tool_call = self.tool_call_regex.findall( + stringified_tool_calls + )[0] + tool_calls = json.loads(raw_tool_call) + except (IndexError, json.JSONDecodeError): + logger.exception("Error in extracting tool call from response: {e}") + # If raw decoding and decoding post regex rule fails, then just + # return content. + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=stringified_tool_calls, + ) + else: + tool_calls = [ + { + "name": tool_call["name"], + "arguments": json.dumps( + tool_call["arguments"], ensure_ascii=False ), - ), - ) - for raw_function_call in function_call_arr - ] + } + for tool_call in tool_calls + ] - # get any content before the tool call - content = model_output.split(self.bot_token)[0] - return ExtractedToolCallInformation( - tools_called=True, - tool_calls=tool_calls, - content=content if len(content) > 0 else None, + mistral_tool_calls: list[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=tool_call["name"], + arguments=tool_call["arguments"], + ), ) + for tool_call in tool_calls + ] - except Exception: - logger.exception("Error in extracting tool call from response.") - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation( - tools_called=False, tool_calls=[], content=tool_content - ) + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=mistral_tool_calls, + content=content if len(content) > 0 else None, + ) def extract_tool_calls_streaming( self,