mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 18:05:01 +08:00
[BUGFIX] Mistral tool call parser v11+ (#30332)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
parent
ee14644ba9
commit
5c213d2899
@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming(
|
|||||||
"single_tool_weather",
|
"single_tool_weather",
|
||||||
"multiple_tool_calls",
|
"multiple_tool_calls",
|
||||||
"content_before_tool",
|
"content_before_tool",
|
||||||
|
"complex",
|
||||||
],
|
],
|
||||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
argvalues=[
|
argvalues=[
|
||||||
@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming(
|
|||||||
],
|
],
|
||||||
"bla",
|
"bla",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
# Complex
|
||||||
|
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="bash",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_extract_tool_calls_streaming_one_chunk(
|
def test_extract_tool_calls_streaming_one_chunk(
|
||||||
|
|||||||
@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
|
|||||||
self.bot_token = "[TOOL_CALLS]"
|
self.bot_token = "[TOOL_CALLS]"
|
||||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
if not _is_pre_v11_tokeniser(self.model_tokenizer):
|
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
|
||||||
self.fn_name_regex = re.compile(
|
|
||||||
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.fn_name_regex = None
|
|
||||||
|
|
||||||
if self.bot_token_id is None:
|
if self.bot_token_id is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -148,23 +143,24 @@ class MistralToolParser(ToolParser):
|
|||||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# we first try to directly load the json as parsing very nested
|
|
||||||
# jsons is difficult
|
|
||||||
try:
|
try:
|
||||||
if self.fn_name_regex:
|
if not self._is_pre_v11:
|
||||||
function_call_arr = []
|
function_call_arr = []
|
||||||
for single_tool_content in model_output.split(self.bot_token):
|
for single_tool_content in model_output.split(self.bot_token):
|
||||||
matches = self.fn_name_regex.findall(single_tool_content)
|
if "{" not in single_tool_content:
|
||||||
|
continue
|
||||||
|
|
||||||
for match in matches:
|
end_name = single_tool_content.find("{")
|
||||||
fn_name = match[0]
|
fn_name, args = (
|
||||||
args = match[1]
|
single_tool_content[:end_name],
|
||||||
|
single_tool_content[end_name:],
|
||||||
|
)
|
||||||
|
|
||||||
# fn_name is encoded outside serialized json dump
|
# fn_name is encoded outside serialized json dump
|
||||||
# only arguments are serialized
|
# only arguments are serialized
|
||||||
function_call_arr.append(
|
function_call_arr.append(
|
||||||
{"name": fn_name, "arguments": json.loads(args)}
|
{"name": fn_name, "arguments": json.loads(args)}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
function_call_arr = json.loads(tool_content)
|
function_call_arr = json.loads(tool_content)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user