mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 15:04:42 +08:00
Add kimi-k2 tool parser (#20789)
Signed-off-by: wangzhengtao <wangzhengtao@moonshot.cn> Co-authored-by: wangzhengtao <wangzhengtao@moonshot.cn> Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
This commit is contained in:
parent
cf75cd2098
commit
0cf893cae1
195
tests/tool_use/test_kimi_k2_tool_parser.py
Normal file
195
tests/tool_use/test_kimi_k2_tool_parser.py
Normal file
@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytest.skip("skip kimi_k2 parser test", allow_module_level=True)
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "moonshotai/Kimi-K2-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def kimi_k2_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kimi_k2_tool_parser(kimi_k2_tokenizer):
|
||||
return KimiK2ToolParser(kimi_k2_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
# assert tool call id format
|
||||
assert actual_tool_call.id.startswith("functions.")
|
||||
assert actual_tool_call.id.split(':')[-1].isdigit()
|
||||
assert actual_tool_call.id.split('.')[1].split(
|
||||
':')[0] == expected_tool_call.function.name
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_with_content_before",
|
||||
"multi_tool_call_with_content_before",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
|
||||
[
|
||||
ToolCall(id='functions.get_weather:0',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"I'll help you check the weather. ",
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
|
||||
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
|
||||
[
|
||||
ToolCall(id='functions.get_weather:0',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
}, ),
|
||||
),
|
||||
type='function'),
|
||||
ToolCall(id='functions.get_weather:1',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Shanghai",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"I'll help you check the weather. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
|
||||
"""we'll return every funcall result"""
|
||||
model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
|
||||
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
|
||||
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
# Should extract only the valid JSON tool calls
|
||||
assert len(extracted_tool_calls.tool_calls) == 2
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "invalid_get_weather"
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
1].function.name == "valid_get_weather"
|
||||
|
||||
|
||||
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
|
||||
"""we'll return every funcall result"""
|
||||
model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
|
||||
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
|
||||
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
# Should extract only the valid JSON tool calls
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "valid_get_weather"
|
||||
|
||||
|
||||
def test_streaming_basic_functionality(kimi_k2_tool_parser):
|
||||
"""Test basic streaming functionality."""
|
||||
# Reset streaming state
|
||||
kimi_k2_tool_parser.current_tool_name_sent = False
|
||||
kimi_k2_tool_parser.prev_tool_call_arr = []
|
||||
kimi_k2_tool_parser.current_tool_id = -1
|
||||
kimi_k2_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
# Test with a simple tool call
|
||||
current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""
|
||||
|
||||
# First call should handle the initial setup
|
||||
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I'll help you",
|
||||
current_text=current_text,
|
||||
delta_text="<|tool_calls_section_end|>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# The result might be None or contain tool call information
|
||||
# This depends on the internal state management
|
||||
if result is not None and hasattr(result,
|
||||
'tool_calls') and result.tool_calls:
|
||||
assert len(result.tool_calls) >= 0
|
||||
|
||||
|
||||
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
|
||||
"""Test streaming when there are no tool calls."""
|
||||
current_text = "This is just regular text without any tool calls."
|
||||
|
||||
result = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="This is just regular text",
|
||||
current_text=current_text,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == " without any tool calls."
|
||||
@ -1143,7 +1143,7 @@ class ModelConfig:
|
||||
if not hasattr(self.hf_text_config, "model_type"):
|
||||
return False
|
||||
elif self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == 'eagle':
|
||||
# if the model is an EAGLE module, check for the
|
||||
|
||||
@ -8,6 +8,7 @@ from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .jamba_tool_parser import JambaToolParser
|
||||
from .kimi_k2_tool_parser import KimiK2ToolParser
|
||||
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
@ -21,5 +22,6 @@ __all__ = [
|
||||
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
|
||||
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
|
||||
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
|
||||
"KimiK2ToolParser"
|
||||
]
|
||||
|
||||
377
vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
Normal file
377
vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
Normal file
@ -0,0 +1,377 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# code modified from deepseekv3_tool_parser.py
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["kimi_k2"])
|
||||
class KimiK2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool_call_begin|>"
|
||||
self.tool_call_end_token: str = "<|tool_call_end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
|
||||
)
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<tool_call_id>[\w\.]+:\d+)\s*")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Kimi-K2 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
logger.debug("function_call_tuples: %s", function_call_tuples)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
function_id, function_args = match
|
||||
# function_id: functions.get_weather:0
|
||||
function_name = function_id.split('.')[1].split(':')[0]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=function_id,
|
||||
type='function',
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_id, tool_args = (current_tool_call_matches.groups())
|
||||
tool_name = tool_id.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_id_str, = current_tool_call_name_matches.groups()
|
||||
tool_name = tool_id_str.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id_str
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
tool_id = current_tool_call.get("id")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
Loading…
x
Reference in New Issue
Block a user