diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index e5deb7f40eb35..2dd0399cb8eeb 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming( "single_tool_weather", "multiple_tool_calls", "content_before_tool", + "complex", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming( ], "bla", ), + ( + # Complex + """[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="bash", + arguments=json.dumps( + {"command": "print(\"hello world!\")\nre.compile(r'{}')"} + ), + ) + ) + ], + "", + ), ], ) def test_extract_tool_calls_streaming_one_chunk( diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index aa5089ffe84d7..bc827f045606c 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -99,12 +99,7 @@ class MistralToolParser(ToolParser): self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) - if not _is_pre_v11_tokeniser(self.model_tokenizer): - self.fn_name_regex = re.compile( - r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL - ) - else: - self.fn_name_regex = None + self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer) if self.bot_token_id is None: raise RuntimeError( @@ -148,23 +143,24 @@ class MistralToolParser(ToolParser): tool_content = model_output.replace(self.bot_token, "").strip() try: - # we first try to directly load the json as parsing very nested - # jsons is difficult try: - if self.fn_name_regex: + if not self._is_pre_v11: function_call_arr = [] for single_tool_content in model_output.split(self.bot_token): - matches = self.fn_name_regex.findall(single_tool_content) + if "{" not in single_tool_content: + continue - for match in matches: - fn_name = match[0] - args = match[1] + end_name = single_tool_content.find("{") + fn_name, args = ( + single_tool_content[:end_name], + single_tool_content[end_name:], + ) - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append( - {"name": fn_name, "arguments": json.loads(args)} - ) + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: