mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 11:03:04 +08:00
Fix edge case Mistral tool parser (#30724)
Signed-off-by: Joachim Studnia <joachim@mistral.ai> Signed-off-by: Joachim Studnia <studniajoachim@gmail.com> Signed-off-by: juliendenize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: juliendenize <julien.denize@mistral.ai> Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
bb62dda2c3
commit
38c361f99d
@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
||||
"single_tool_add",
|
||||
"single_tool_weather",
|
||||
"multiple_tool_calls",
|
||||
"complex",
|
||||
"wrong_json",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
)[:-2],
|
||||
)
|
||||
)
|
||||
],
|
||||
"hi{hi",
|
||||
),
|
||||
(
|
||||
# Wrong json
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"hi{hi",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
"hi{hi",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -131,78 +131,105 @@ class MistralToolParser(ToolParser):
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
make sure your tool call arguments don't ever include quotes!
|
||||
Extract the tool calls from a complete model response.
|
||||
|
||||
Content and tool call formatting depends on the Mistral tokenizer version
|
||||
used to train the model:
|
||||
|
||||
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
|
||||
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
|
||||
|
||||
with [BOT] the tool call token.
|
||||
|
||||
Note:
|
||||
For tokenizer versions >= v11, tool calls with arguments wrongly formatted
|
||||
are still returned as tool calls. This is to allow the model to know it
|
||||
tried to make a tool call. It reduces the chance of another failure and
|
||||
prevents the context from being filled with tool calls wrongly placed in
|
||||
assistant message contents.
|
||||
"""
|
||||
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
# If the tool call token is not present, return a text response
|
||||
if self.bot_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# first remove the BOT token
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
content_and_raw_tool_calls = model_output.split(self.bot_token)
|
||||
content = content_and_raw_tool_calls[0]
|
||||
raw_tool_calls = content_and_raw_tool_calls[1:]
|
||||
|
||||
try:
|
||||
# >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
|
||||
if not self._is_pre_v11:
|
||||
tool_calls = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if "{" not in raw_tool_call:
|
||||
continue
|
||||
|
||||
end_name = raw_tool_call.find("{")
|
||||
tool_name, args = (
|
||||
raw_tool_call[:end_name],
|
||||
raw_tool_call[end_name:],
|
||||
)
|
||||
|
||||
tool_calls.append({"name": tool_name, "arguments": args})
|
||||
|
||||
# < v11: content[BOT] [{tool_call1},{tool_call2}]
|
||||
else:
|
||||
if len(raw_tool_calls) != 1:
|
||||
raise ValueError(
|
||||
"Only one BOT token should have been outputted, "
|
||||
f"but got {model_output}."
|
||||
)
|
||||
stringified_tool_calls = raw_tool_calls[0].strip()
|
||||
try:
|
||||
if not self._is_pre_v11:
|
||||
function_call_arr = []
|
||||
for single_tool_content in model_output.split(self.bot_token):
|
||||
if "{" not in single_tool_content:
|
||||
continue
|
||||
|
||||
end_name = single_tool_content.find("{")
|
||||
fn_name, args = (
|
||||
single_tool_content[:end_name],
|
||||
single_tool_content[end_name:],
|
||||
)
|
||||
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append(
|
||||
{"name": fn_name, "arguments": json.loads(args)}
|
||||
)
|
||||
else:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
tool_calls = json.loads(stringified_tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
# use a regex to find the part corresponding to the tool call.
|
||||
# NOTE: This use case should not happen if the model is trained
|
||||
# correctly. It's an easy possible fix so it's included, but
|
||||
# can be brittle for very complex / highly nested tool calls
|
||||
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
|
||||
# Tool Call
|
||||
tool_calls: list[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
raw_function_call["arguments"], ensure_ascii=False
|
||||
try:
|
||||
raw_tool_call = self.tool_call_regex.findall(
|
||||
stringified_tool_calls
|
||||
)[0]
|
||||
tool_calls = json.loads(raw_tool_call)
|
||||
except (IndexError, json.JSONDecodeError):
|
||||
logger.exception("Error in extracting tool call from response: {e}")
|
||||
# If both raw decoding and post-regex decoding fail, then just
|
||||
# return content.
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=stringified_tool_calls,
|
||||
)
|
||||
else:
|
||||
tool_calls = [
|
||||
{
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(
|
||||
tool_call["arguments"], ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
content = model_output.split(self.bot_token)[0]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if len(content) > 0 else None,
|
||||
mistral_tool_calls: list[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tool_call["name"],
|
||||
arguments=tool_call["arguments"],
|
||||
),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=tool_content
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=mistral_tool_calls,
|
||||
content=content if len(content) > 0 else None,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user