mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:25:01 +08:00
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> Signed-off-by: Chauncey <chaunceyjiang@gmail.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Co-authored-by: Jeff Cook <jeff@jeffcook.io> Co-authored-by: sfbemerk <benjaminmerkel@mail.de> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
848 lines
28 KiB
Python
848 lines
28 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
from collections.abc import Generator
|
|
|
|
import partial_json_parser
|
|
import pytest
|
|
from mistral_common.protocol.instruct.messages import AssistantMessage
|
|
from mistral_common.protocol.instruct.request import InstructRequest
|
|
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
|
|
from partial_json_parser.core.options import Allow
|
|
|
|
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
|
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
|
|
from vllm.tokenizers import (
|
|
MistralTokenizer,
|
|
TokenizerLike,
|
|
get_tokenizer,
|
|
)
|
|
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
|
|
|
|
|
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    """Tokenizer for a pre-v11 Mistral model, shared by every test in the module.

    Loaded through ``get_tokenizer`` without an explicit ``tokenizer_mode``.
    """
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    return get_tokenizer(tokenizer_name=model_name)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Module-wide tokenizer for a recent Mistral model, loaded in "mistral" mode."""
    model_name = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    return get_tokenizer(tokenizer_name=model_name, tokenizer_mode="mistral")
|
|
|
|
|
|
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    """A fresh ``MistralToolParser`` (no fixture scope given) on the pre-v11 tokenizer."""
    parser = MistralToolParser(mistral_pre_v11_tokenizer)
    return parser
|
|
|
|
|
|
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """A fresh ``MistralToolParser`` (no fixture scope given) on the "mistral"-mode tokenizer."""
    parser = MistralToolParser(mistral_tokenizer)
    return parser
|
|
|
|
|
|
def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Assert that ``actual_tool_calls`` matches ``expected_tool_calls`` pairwise.

    ``actual_tool_calls`` may contain ``DeltaToolCall`` objects (single-chunk
    streaming) or complete tool-call objects (non-streaming parsing, or calls
    rebuilt from a stream). For every pair:

    * the actual id must be a 9-character string (the expected calls carry no
      id to compare against);
    * delta calls must carry a function name and arguments, while complete
      calls must have ``type == "function"``;
    * the function name and the arguments string must equal the expected ones.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        assert actual_tool_call.function is not None
        if isinstance(actual_tool_call, DeltaToolCall):
            # Delta fields are optional; a one-chunk delta must still fill them.
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None
        else:
            # Checked on every non-delta call rather than via
            # isinstance(..., ToolCall): NOTE(review) the non-streaming parser
            # output is presumably not mistral_common's ToolCall class, which
            # previously made this check silently skip. Confirm.
            assert actual_tool_call.type == "function"

        actual_name = actual_tool_call.function.name
        expected_name = expected_tool_call.function.name
        assert actual_name == expected_name, (
            f"got wrong function name: {actual_name!r}, expected {expected_name!r}"
        )

        actual_args = actual_tool_call.function.arguments
        expected_args = expected_tool_call.function.arguments
        assert actual_args == expected_args, (
            f"got wrong function argument: {actual_args!r}, expected {expected_args!r}"
        )
|
|
|
|
|
|
def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
):
    """
    Replaces the textual token sequence for [TOOL_CALLS]
    with its single special token ID.

    ``mistral_tool_parser.bot_token`` is encoded as plain text, and every
    non-overlapping occurrence of the resulting id sequence in ``tokens``
    (scanned left to right) is collapsed into
    ``mistral_tool_parser.bot_token_id``.

    Args:
        tokens: Token ids to rewrite.
        mistral_tool_parser: Supplies ``bot_token`` and ``bot_token_id``.
        mistral_tokenizer: Used to encode ``bot_token`` as plain text.

    Returns:
        A new list with the replacements applied. ``tokens`` itself is returned
        unchanged when it is shorter than the textual sequence, or when that
        sequence encodes to nothing (matching an empty sequence would otherwise
        never advance the cursor and loop forever).
    """
    textual_tool_call_token_ids = list(
        mistral_tokenizer.encode(
            text=mistral_tool_parser.bot_token,
            # textual ids must not contain special tokens like bos, eos etc.
            add_special_tokens=False,
        )
    )
    target_len = len(textual_tool_call_token_ids)

    # Nothing to replace: an empty target cannot be matched meaningfully, and
    # an input shorter than the target cannot contain it.
    if target_len == 0 or len(tokens) < target_len:
        return tokens

    special_tool_call_token_id = mistral_tool_parser.bot_token_id
    result_tokens: list[int] = []
    i = 0
    while i < len(tokens):
        if tokens[i : i + target_len] == textual_tool_call_token_ids:
            # Collapse the whole textual sequence into the single special id.
            result_tokens.append(special_tool_call_token_id)
            i += target_len
        else:
            result_tokens.append(tokens[i])
            i += 1

    return result_tokens
|
|
|
|
|
|
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    """Replay a model response one token at a time through the streaming parser.

    Token ids are obtained in one of two ways:

    * ``MistralTokenizer`` version 11 or newer cannot tokenize free text, so
      an assistant message whose tool calls come from ``tools``
      (``(name, arguments)`` pairs) is encoded as an instruct request;
    * any other tokenizer encodes ``model_output`` directly, without special
      tokens.

    The textual ``[TOOL_CALLS]`` sequence is then collapsed into its special
    token id. Each token is detokenized incrementally and handed to
    ``extract_tool_calls_streaming``; every truthy ``DeltaMessage`` returned is
    yielded.
    """
    uses_instruct_encoding = (
        isinstance(mistral_tokenizer, MistralTokenizer)
        and mistral_tokenizer.version >= 11
    )
    if uses_instruct_encoding:
        assert tools is not None
        encoded_calls = [
            ToolCall(
                function=FunctionCall(
                    name=tool_name,
                    arguments=tool_args,
                )
            )
            for tool_name, tool_args in tools
        ]
        instruct_request = InstructRequest(
            messages=[AssistantMessage(tool_calls=encoded_calls)],
        )
        token_ids = mistral_tokenizer.instruct.encode_instruct(instruct_request).tokens
    else:
        assert model_output is not None
        token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)

    token_ids = fix_tool_call_tokenization(
        token_ids, mistral_tool_parser, mistral_tokenizer
    )

    # Incremental-detokenization state carried from one step to the next.
    text_so_far = ""
    detok_tokens = None
    prefix_offset = 0
    read_offset = 0
    skip_special = isinstance(mistral_tokenizer, MistralTokenizer)

    for idx, token_id in enumerate(token_ids):
        seen_ids = token_ids[:idx]
        upto_ids = token_ids[: idx + 1]

        new_tokens, delta_text, next_prefix_offset, next_read_offset = (
            detokenize_incrementally(
                tokenizer=mistral_tokenizer,
                all_input_ids=upto_ids,
                prev_tokens=detok_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=skip_special,
                spaces_between_special_tokens=True,
            )
        )
        current_text = text_so_far + delta_text

        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            current_text,
            delta_text,
            seen_ids,
            upto_ids,
            [token_id],
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        text_so_far = current_text
        detok_tokens = detok_tokens + new_tokens if detok_tokens else new_tokens
        prefix_offset, read_offset = next_prefix_offset, next_read_offset
|
|
|
|
|
|
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    """Text without a tool-call marker comes back unchanged as plain content."""
    model_output = "This is a test"
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming parsing of the pre-v11 ``[TOOL_CALLS] [ {...} ]`` JSON format."""
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply",
                        arguments=json.dumps({"a": 3, "b": 6}),
                    )
                ),
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming parsing of the ``[TOOL_CALLS]name{json-args}`` format."""
    result = mistral_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
|
|
|
|
|
|
def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    """Stream a response through ``tool_parser`` and check what it reassembles to.

    Deltas come from ``stream_delta_message_generator``. Text content is
    concatenated and compared with ``expected_content``. Tool-call fragments
    are grouped by their ``index`` into (id, name, arguments) triples, rebuilt
    as ``ToolCall`` objects (arguments completed with ``partial_json_parser``)
    and compared with ``expected_tool_calls``.
    """
    streamed_content = ""
    names: list[str] = []
    arguments_by_call: list[str] = []
    ids_by_call: list[str | None] = []
    active_index = -1

    for delta in stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    ):
        # The tool parser must never emit a role.
        assert not delta.role

        if delta.content:
            streamed_content += delta.content

        if not delta.tool_calls:
            continue

        # Each delta carries exactly one tool-call fragment, even when several
        # tools are being called.
        assert len(delta.tool_calls) == 1
        fragment = delta.tool_calls[0]

        assert len(tool_parser.prev_tool_call_arr) > 0

        # A change of index opens a new argument buffer and id slot; indices
        # are assumed to arrive in increasing order, one call at a time.
        if fragment.index != active_index:
            active_index = fragment.index
            arguments_by_call.append("")
            ids_by_call.append(None)

        # Only the first id streamed for a call is kept; later ids are
        # ignored rather than asserted against.
        if fragment.id and not ids_by_call[fragment.index]:
            ids_by_call[fragment.index] = fragment.id

        if not fragment.function:
            continue

        # Names are collected in arrival order, one entry per non-empty name
        # fragment; they are expected to be streamed whole, once per call.
        if fragment.function.name:
            assert isinstance(fragment.function.name, str)
            names.append(fragment.function.name)

        if fragment.function.arguments:
            assert isinstance(fragment.function.arguments, str)
            arguments_by_call[fragment.index] += fragment.function.arguments

    assert streamed_content == expected_content

    rebuilt_calls = [
        ToolCall(
            id=call_id,
            function=FunctionCall(
                name=call_name,
                arguments=partial_json_parser.ensure_json(
                    call_args, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for call_id, call_name, call_args in zip(ids_by_call, names, arguments_by_call)
    ]
    assert_tool_calls(rebuilt_calls, expected_tool_calls)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming of pre-v11 output, tokenized from raw text."""
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_pre_v11_tool_parser,
        tokenizer=mistral_pre_v11_tokenizer,
        model_output=model_output,
        tools=None,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_add_strings",
        "multiple_tools",
    ],
    argnames=["tools", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            [("add", '{"a": 3, "b": 4}')],
            # [TOOL_CALLS]add{"a": 3, "b": 4}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            [("add_two_strings", '{"a": "3", "b": "4"}')],
            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_two_strings",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
        ),
        (
            [
                ("add", '{"a": 3.5, "b": 4}'),
                (
                    "get_current_weather",
                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
                ),
            ],
            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming where the tokens are built from ``tools``.

    The comment above each case shows the textual model output its tokens
    correspond to.
    """
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_tool_parser,
        tokenizer=mistral_tokenizer,
        model_output=None,
        tools=tools,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "single_tool_add",
        "single_tool_weather",
        "multiple_tool_calls",
        "content_before_tool",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply",
                        arguments=json.dumps({"a": 3, "b": 6}),
                    )
                ),
            ],
            "",
        ),
        (
            # Extra content is placed before the tool calls, not after them.
            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "bla",
        ),
    ],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Feed the whole output to the streaming parser as one single delta."""
    # Only non-MistralTokenizer tokenizers are asked to skip special tokens.
    encode_kwargs = (
        {}
        if isinstance(mistral_tokenizer, MistralTokenizer)
        else {"add_special_tokens": False}
    )
    token_ids = fix_tool_call_tokenization(
        mistral_tokenizer.encode(model_output, **encode_kwargs),
        mistral_tool_parser,
        mistral_tokenizer,
    )

    delta = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=token_ids,
        delta_token_ids=token_ids,
        request=None,  # type: ignore[arg-type]
    )

    assert isinstance(delta, DeltaMessage)
    assert len(delta.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta.tool_calls, expected_tool_calls)

    # A missing content field counts as empty content.
    assert (delta.content or "") == expected_content
|
|
|
|
|
|
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("""This is a test""", [], """This is a test"""),
        (
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3, "b": 4}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Feed a whole pre-v11 output to the streaming parser as one single delta."""
    # Only non-MistralTokenizer tokenizers are asked to skip special tokens.
    encode_kwargs = (
        {}
        if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer)
        else {"add_special_tokens": False}
    )
    token_ids = fix_tool_call_tokenization(
        mistral_pre_v11_tokenizer.encode(model_output, **encode_kwargs),
        mistral_pre_v11_tool_parser,
        mistral_pre_v11_tokenizer,
    )

    delta = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=token_ids,
        delta_token_ids=token_ids,
        request=None,  # type: ignore[arg-type]
    )

    assert isinstance(delta, DeltaMessage)
    assert len(delta.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta.tool_calls, expected_tool_calls)

    # A missing content field counts as empty content.
    assert (delta.content or "") == expected_content
|