Fix edge case Mistral tool parser (#30724)

Signed-off-by: Joachim Studnia <joachim@mistral.ai>
Signed-off-by: Joachim Studnia <studniajoachim@gmail.com>
Signed-off-by: juliendenize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: juliendenize <julien.denize@mistral.ai>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Joachim Studnia 2025-12-23 15:19:58 +01:00 committed by GitHub
parent bb62dda2c3
commit 38c361f99d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 115 additions and 56 deletions

View File

@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add", "single_tool_add",
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"complex",
"wrong_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
], ],
None, None,
), ),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
], ],
) )
def test_extract_tool_calls( def test_extract_tool_calls(
@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
), ),
( (
# Complex # Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[ [
ToolCall( ToolCall(
function=FunctionCall( function=FunctionCall(
@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
) )
) )
], ],
"", "hi{hi",
), ),
], ],
) )

View File

@ -131,78 +131,105 @@ class MistralToolParser(ToolParser):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
""" """
Extract the tool calls from a complete model response. Requires Extract the tool calls from a complete model response.
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes! Content and tool calls formatting depends on the Mistral's tokenizer version
used to train the model:
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
with [BOT] the tool call token.
Note:
For tokenizer versions >= v11, tool calls with arguments wrongly formatted
are still returned as tool calls. This is to allow the model to know it
tried to make a tool call. It reduces chance of another failure and
prevents that the context is filled with tool calls wrongly placed in
assistant message contents.
""" """
# case -- if a tool call token is not present, return a text response # If the tool call token is not present, return a text response
if self.bot_token not in model_output: if self.bot_token not in model_output:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output tools_called=False, tool_calls=[], content=model_output
) )
# first remove the BOT token content_and_raw_tool_calls = model_output.split(self.bot_token)
tool_content = model_output.replace(self.bot_token, "").strip() content = content_and_raw_tool_calls[0]
raw_tool_calls = content_and_raw_tool_calls[1:]
try: # >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
if not self._is_pre_v11:
tool_calls = []
for raw_tool_call in raw_tool_calls:
if "{" not in raw_tool_call:
continue
end_name = raw_tool_call.find("{")
tool_name, args = (
raw_tool_call[:end_name],
raw_tool_call[end_name:],
)
tool_calls.append({"name": tool_name, "arguments": args})
# < v11: content[BOT] [{tool_call1},{tool_call2}]
else:
if len(raw_tool_calls) != 1:
raise ValueError(
"Only one BOT token should have been outputted, "
f"but got {model_output}."
)
stringified_tool_calls = raw_tool_calls[0].strip()
try: try:
if not self._is_pre_v11: tool_calls = json.loads(stringified_tool_calls)
function_call_arr = []
for single_tool_content in model_output.split(self.bot_token):
if "{" not in single_tool_content:
continue
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError: except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call. # use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained # NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but # correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls # can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0] try:
function_call_arr = json.loads(raw_tool_call) raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls
# Tool Call )[0]
tool_calls: list[MistralToolCall] = [ tool_calls = json.loads(raw_tool_call)
MistralToolCall( except (IndexError, json.JSONDecodeError):
type="function", logger.exception("Error in extracting tool call from response: {e}")
function=FunctionCall( # If raw decoding and decoding post regex rule fails, then just
name=raw_function_call["name"], # return content.
# function call args are JSON but as a string return ExtractedToolCallInformation(
arguments=json.dumps( tools_called=False,
raw_function_call["arguments"], ensure_ascii=False tool_calls=[],
content=stringified_tool_calls,
)
else:
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False
), ),
), }
) for tool_call in tool_calls
for raw_function_call in function_call_arr ]
]
# get any content before the tool call mistral_tool_calls: list[MistralToolCall] = [
content = model_output.split(self.bot_token)[0] MistralToolCall(
return ExtractedToolCallInformation( type="function",
tools_called=True, function=FunctionCall(
tool_calls=tool_calls, name=tool_call["name"],
content=content if len(content) > 0 else None, arguments=tool_call["arguments"],
),
) )
for tool_call in tool_calls
]
except Exception: return ExtractedToolCallInformation(
logger.exception("Error in extracting tool call from response.") tools_called=True,
# return information to just treat the tool call as regular JSON tool_calls=mistral_tool_calls,
return ExtractedToolCallInformation( content=content if len(content) > 0 else None,
tools_called=False, tool_calls=[], content=tool_content )
)
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
self, self,