vllm/tests/tool_use/test_kimi_k2_tool_parser.py
Chauncey c02fccdbd2
[Refactor] Lazy import tool_parser (#27974)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-11-04 10:10:10 +08:00

212 lines
8.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Module-level pytest mark: tags every test in this file with ``cpu_test``.
pytestmark = pytest.mark.cpu_test
# Tokenizer name handed to ``get_tokenizer`` by the ``kimi_k2_tokenizer``
# fixture below.
MODEL = "moonshotai/Kimi-K2-Instruct"
@pytest.fixture(scope="module")
def kimi_k2_tokenizer():
    """Module-scoped fixture returning the tokenizer named by ``MODEL``.

    ``trust_remote_code=True`` is forwarded to ``get_tokenizer``.
    """
    tokenizer = get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
    return tokenizer
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
    """Function-scoped fixture: a new ``KimiK2ToolParser`` built on the shared tokenizer."""
    parser = KimiK2ToolParser(kimi_k2_tokenizer)
    return parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Assert that the parsed tool calls match the expected ones, pairwise.

    Both lists must have the same length. For every pair, the actual call
    must have type ``"function"``, carry a ``function`` equal to the expected
    one, and have an id that ends in a numeric index and names the expected
    function.

    Args:
        actual_tool_calls: Tool calls produced by the parser under test.
        expected_tool_calls: Reference tool calls to compare against.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for got, want in zip(actual_tool_calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function == want.function

        # Accepted id shapes are "functions.func_name:0" and "func_name:0".
        # Whatever follows the last ":" must be the numeric index.
        index_part = got.id.rpartition(":")[2]
        assert index_part.isdigit()

        # Whatever precedes the first ":" holds the (optionally dotted) name;
        # only its last dotted component is compared.
        name_part = got.id.partition(":")[0].rpartition(".")[2]
        assert name_part == want.function.name
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    """Text without tool-call markers yields no tool calls and unchanged content."""
    plain_text = "This is a test"

    parsed = kimi_k2_tool_parser.extract_tool_calls(plain_text, request=None)  # type: ignore[arg-type]

    assert not parsed.tools_called
    assert parsed.tool_calls == []
    assert parsed.content == plain_text
def _get_weather_call(call_id: str, city: str) -> ToolCall:
    """Build the expected ``get_weather`` tool call with the given id and city."""
    return ToolCall(
        id=call_id,
        function=FunctionCall(
            name="get_weather",
            arguments=json.dumps({"city": city}),
        ),
        type="function",
    )


@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        # One tool call preceded by free-form content.
        pytest.param(
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [_get_weather_call("functions.get_weather:0", "Beijing")],
            "I'll help you check the weather. ",
            id="tool_call_with_content_before",
        ),
        # Two tool calls in one section, preceded by free-form content.
        pytest.param(
            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
            [
                _get_weather_call("functions.get_weather:0", "Beijing"),
                _get_weather_call("functions.get_weather:1", "Shanghai"),
            ],
            "I'll help you check the weather. ",
            id="multi_tool_call_with_content_before",
        ),
    ],
)
def test_extract_tool_calls(
    kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction returns the expected tool calls and leading content."""
    parsed = kimi_k2_tool_parser.extract_tool_calls(model_output, request=None)  # type: ignore[arg-type]

    assert parsed.tools_called
    assert_tool_calls(parsed.tool_calls, expected_tool_calls)
    assert parsed.content == expected_content
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """A tool call with malformed JSON arguments is still returned with the valid one."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    parsed = kimi_k2_tool_parser.extract_tool_calls(model_output, request=None)  # type: ignore[arg-type]

    assert parsed.tools_called
    # Both calls are expected back, in order: the first one has truncated
    # JSON arguments (missing closing brace) but is not dropped.
    assert len(parsed.tool_calls) == 2
    returned_names = [call.function.name for call in parsed.tool_calls]
    assert returned_names == ["invalid_get_weather", "valid_get_weather"]
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """A tool call whose id is malformed is skipped; the well-formed one is returned."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""

    parsed = kimi_k2_tool_parser.extract_tool_calls(model_output, request=None)  # type: ignore[arg-type]

    assert parsed.tools_called
    # The first call uses ".0" instead of ":0" as its index separator, so only
    # the second, well-formed call is expected in the result.
    assert len(parsed.tool_calls) == 1
    returned_names = [call.function.name for call in parsed.tool_calls]
    assert returned_names == ["valid_get_weather"]
def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Smoke-test that one streaming extraction step runs without raising.

    NOTE(review): the only assertion here (``len(...) >= 0``) can never fail,
    and ``previous_text + delta_text`` does not equal ``current_text``. The
    test therefore only checks that the call does not raise. Tighten it once
    the expected streaming output for this input is confirmed.
    """
    # Put the parser's streaming bookkeeping back to a clean state.
    parser = kimi_k2_tool_parser
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    # A complete section containing one simple tool call.
    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""

    # Feed a single streaming step whose delta is the section-end marker.
    result = parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=current_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Depending on internal state the result may be None or carry tool-call
    # deltas; only inspect the list when it is present and non-empty.
    streamed_calls = getattr(result, "tool_calls", None)
    if streamed_calls:
        assert len(streamed_calls) >= 0
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """With no tool-call markers, the streaming step returns the delta as content."""
    previous_text = "This is just regular text"
    delta_text = " without any tool calls."
    current_text = previous_text + delta_text

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=previous_text,
        current_text=current_text,
        delta_text=delta_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The delta text is expected to come back untouched as plain content.
    assert result is not None
    assert hasattr(result, "content")
    assert result.content == delta_text