mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 11:07:06 +08:00
Fix edge case Mistral tool parser (#30724)
Signed-off-by: Joachim Studnia <joachim@mistral.ai>
Signed-off-by: Joachim Studnia <studniajoachim@gmail.com>
Signed-off-by: juliendenize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: juliendenize <julien.denize@mistral.ai>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
bb62dda2c3
commit
38c361f99d
@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
|||||||
"single_tool_add",
|
"single_tool_add",
|
||||||
"single_tool_weather",
|
"single_tool_weather",
|
||||||
"multiple_tool_calls",
|
"multiple_tool_calls",
|
||||||
|
"complex",
|
||||||
|
"wrong_json",
|
||||||
],
|
],
|
||||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
argvalues=[
|
argvalues=[
|
||||||
@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
|||||||
],
|
],
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
# Complex
|
||||||
|
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="bash",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||||
|
)[:-2],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"hi{hi",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
# Wrong json
|
||||||
|
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="bash",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"hi{hi",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_extract_tool_calls(
|
def test_extract_tool_calls(
|
||||||
@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
# Complex
|
# Complex
|
||||||
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||||
[
|
[
|
||||||
ToolCall(
|
ToolCall(
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
"",
|
"hi{hi",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -131,78 +131,105 @@ class MistralToolParser(ToolParser):
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> ExtractedToolCallInformation:
|
) -> ExtractedToolCallInformation:
|
||||||
"""
|
"""
|
||||||
Extract the tool calls from a complete model response. Requires
|
Extract the tool calls from a complete model response.
|
||||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
|
||||||
make sure your tool call arguments don't ever include quotes!
|
Content and tool call formatting depend on the Mistral tokenizer version
|
||||||
|
used to train the model:
|
||||||
|
|
||||||
|
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
|
||||||
|
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
|
||||||
|
|
||||||
|
with [BOT] the tool call token.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
For tokenizer versions >= v11, tool calls with arguments wrongly formatted
|
||||||
|
are still returned as tool calls. This is to allow the model to know it
|
||||||
|
tried to make a tool call. It reduces the chance of another failure and
|
||||||
|
prevents the context from being filled with tool calls wrongly placed in
|
||||||
|
assistant message contents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# case -- if a tool call token is not present, return a text response
|
# If the tool call token is not present, return a text response
|
||||||
if self.bot_token not in model_output:
|
if self.bot_token not in model_output:
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=False, tool_calls=[], content=model_output
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
)
|
)
|
||||||
|
|
||||||
# first remove the BOT token
|
content_and_raw_tool_calls = model_output.split(self.bot_token)
|
||||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
content = content_and_raw_tool_calls[0]
|
||||||
|
raw_tool_calls = content_and_raw_tool_calls[1:]
|
||||||
|
|
||||||
try:
|
# >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
|
||||||
|
if not self._is_pre_v11:
|
||||||
|
tool_calls = []
|
||||||
|
for raw_tool_call in raw_tool_calls:
|
||||||
|
if "{" not in raw_tool_call:
|
||||||
|
continue
|
||||||
|
|
||||||
|
end_name = raw_tool_call.find("{")
|
||||||
|
tool_name, args = (
|
||||||
|
raw_tool_call[:end_name],
|
||||||
|
raw_tool_call[end_name:],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls.append({"name": tool_name, "arguments": args})
|
||||||
|
|
||||||
|
# < v11: content[BOT] [{tool_call1},{tool_call2}]
|
||||||
|
else:
|
||||||
|
if len(raw_tool_calls) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Only one BOT token should have been outputted, "
|
||||||
|
f"but got {model_output}."
|
||||||
|
)
|
||||||
|
stringified_tool_calls = raw_tool_calls[0].strip()
|
||||||
try:
|
try:
|
||||||
if not self._is_pre_v11:
|
tool_calls = json.loads(stringified_tool_calls)
|
||||||
function_call_arr = []
|
|
||||||
for single_tool_content in model_output.split(self.bot_token):
|
|
||||||
if "{" not in single_tool_content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
end_name = single_tool_content.find("{")
|
|
||||||
fn_name, args = (
|
|
||||||
single_tool_content[:end_name],
|
|
||||||
single_tool_content[end_name:],
|
|
||||||
)
|
|
||||||
|
|
||||||
# fn_name is encoded outside serialized json dump
|
|
||||||
# only arguments are serialized
|
|
||||||
function_call_arr.append(
|
|
||||||
{"name": fn_name, "arguments": json.loads(args)}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
function_call_arr = json.loads(tool_content)
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# use a regex to find the part corresponding to the tool call.
|
# use a regex to find the part corresponding to the tool call.
|
||||||
# NOTE: This use case should not happen if the model is trained
|
# NOTE: This use case should not happen if the model is trained
|
||||||
# correctly. It's an easy possible fix so it's included, but
|
# correctly. It's an easy possible fix so it's included, but
|
||||||
# can be brittle for very complex / highly nested tool calls
|
# can be brittle for very complex / highly nested tool calls
|
||||||
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
try:
|
||||||
function_call_arr = json.loads(raw_tool_call)
|
raw_tool_call = self.tool_call_regex.findall(
|
||||||
|
stringified_tool_calls
|
||||||
# Tool Call
|
)[0]
|
||||||
tool_calls: list[MistralToolCall] = [
|
tool_calls = json.loads(raw_tool_call)
|
||||||
MistralToolCall(
|
except (IndexError, json.JSONDecodeError):
|
||||||
type="function",
|
logger.exception("Error in extracting tool call from response: {e}")
|
||||||
function=FunctionCall(
|
# If raw decoding and decoding post regex rule fails, then just
|
||||||
name=raw_function_call["name"],
|
# return content.
|
||||||
# function call args are JSON but as a string
|
return ExtractedToolCallInformation(
|
||||||
arguments=json.dumps(
|
tools_called=False,
|
||||||
raw_function_call["arguments"], ensure_ascii=False
|
tool_calls=[],
|
||||||
|
content=stringified_tool_calls,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_calls = [
|
||||||
|
{
|
||||||
|
"name": tool_call["name"],
|
||||||
|
"arguments": json.dumps(
|
||||||
|
tool_call["arguments"], ensure_ascii=False
|
||||||
),
|
),
|
||||||
),
|
}
|
||||||
)
|
for tool_call in tool_calls
|
||||||
for raw_function_call in function_call_arr
|
]
|
||||||
]
|
|
||||||
|
|
||||||
# get any content before the tool call
|
mistral_tool_calls: list[MistralToolCall] = [
|
||||||
content = model_output.split(self.bot_token)[0]
|
MistralToolCall(
|
||||||
return ExtractedToolCallInformation(
|
type="function",
|
||||||
tools_called=True,
|
function=FunctionCall(
|
||||||
tool_calls=tool_calls,
|
name=tool_call["name"],
|
||||||
content=content if len(content) > 0 else None,
|
arguments=tool_call["arguments"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
for tool_call in tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
except Exception:
|
return ExtractedToolCallInformation(
|
||||||
logger.exception("Error in extracting tool call from response.")
|
tools_called=True,
|
||||||
# return information to just treat the tool call as regular JSON
|
tool_calls=mistral_tool_calls,
|
||||||
return ExtractedToolCallInformation(
|
content=content if len(content) > 0 else None,
|
||||||
tools_called=False, tool_calls=[], content=tool_content
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def extract_tool_calls_streaming(
|
def extract_tool_calls_streaming(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user