diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py
new file mode 100644
index 0000000000000..8768203a7119a
--- /dev/null
+++ b/tests/tool_use/test_kimi_k2_tool_parser.py
@@ -0,0 +1,195 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+
+import json
+
+import pytest
+
+from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+pytest.skip("skip kimi_k2 parser test", allow_module_level=True)
+
+# The Kimi-K2 tokenizer is downloaded from the HF Hub when these tests run
+MODEL = "moonshotai/Kimi-K2-Instruct"
+
+
+@pytest.fixture(scope="module")
+def kimi_k2_tokenizer():
+    return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
+
+
+@pytest.fixture
+def kimi_k2_tool_parser(kimi_k2_tokenizer):
+    return KimiK2ToolParser(kimi_k2_tokenizer)
+
+
+def assert_tool_calls(actual_tool_calls: list[ToolCall],
+                      expected_tool_calls: list[ToolCall]):
+    assert len(actual_tool_calls) == len(expected_tool_calls)
+
+    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+                                                    expected_tool_calls):
+
+        assert actual_tool_call.type == "function"
+        assert actual_tool_call.function == expected_tool_call.function
+
+        # assert the tool call id format: functions.{name}:{index}
+        assert actual_tool_call.id.startswith("functions.")
+        assert actual_tool_call.id.split(':')[-1].isdigit()
+        assert actual_tool_call.id.split('.')[1].split(
+            ':')[0] == expected_tool_call.function.name
+
+
+def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
+    model_output = "This is a test"
+    extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert not extracted_tool_calls.tools_called
+    assert extracted_tool_calls.tool_calls == []
+    assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "tool_call_with_content_before",
+        "multi_tool_call_with_content_before",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
+functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
+            [
+                ToolCall(id='functions.get_weather:0',
+                         function=FunctionCall(
+                             name="get_weather",
+                             arguments=json.dumps({
+                                 "city": "Beijing",
+                             }),
+                         ),
+                         type='function')
+            ],
+            "I'll help you check the weather. ",
+        ),
", + ), + ], +) +def test_extract_tool_calls(kimi_k2_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser): + """we'll return every funcall result""" + model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|> +functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" + + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + # Should extract only the valid JSON tool calls + assert len(extracted_tool_calls.tool_calls) == 2 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "invalid_get_weather" + assert extracted_tool_calls.tool_calls[ + 1].function.name == "valid_get_weather" + + +def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser): + """we'll return every funcall result""" + model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|> +functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" + + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + # Should extract only the valid JSON tool calls + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "valid_get_weather" + + +def test_streaming_basic_functionality(kimi_k2_tool_parser): + """Test basic streaming functionality.""" + # Reset streaming state + kimi_k2_tool_parser.current_tool_name_sent = False + kimi_k2_tool_parser.prev_tool_call_arr = [] + kimi_k2_tool_parser.current_tool_id = -1 + kimi_k2_tool_parser.streamed_args_for_tool = [] + + # Test with a simple tool call + current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""" + + # First call should handle the initial setup + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you", + current_text=current_text, + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # The result might be None or contain tool call information + # This depends on the internal state management + if result is not None and hasattr(result, + 'tool_calls') and result.tool_calls: + assert len(result.tool_calls) >= 0 + + +def test_streaming_no_tool_calls(kimi_k2_tool_parser): + """Test streaming when there are no tool calls.""" + current_text = "This is just regular text without any tool calls." 
+def test_streaming_basic_functionality(kimi_k2_tool_parser):
+    """Test basic streaming functionality."""
+    # Reset streaming state
+    kimi_k2_tool_parser.current_tool_name_sent = False
+    kimi_k2_tool_parser.prev_tool_call_arr = []
+    kimi_k2_tool_parser.current_tool_id = -1
+    kimi_k2_tool_parser.streamed_args_for_tool = []
+
+    # Test with a simple tool call
+    current_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
+functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""
+
+    # First call should handle the initial setup
+    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+        previous_text="I'll help you",
+        current_text=current_text,
+        delta_text="<|tool_calls_section_end|>",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    # The result might be None or contain tool call information,
+    # depending on the parser's internal state management.
+    if result is not None and hasattr(result,
+                                      'tool_calls') and result.tool_calls:
+        assert len(result.tool_calls) >= 1
+
+
+def test_streaming_no_tool_calls(kimi_k2_tool_parser):
+    """Test streaming when there are no tool calls."""
+    current_text = "This is just regular text without any tool calls."
+
+    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+        previous_text="This is just regular text",
+        current_text=current_text,
+        delta_text=" without any tool calls.",
+        previous_token_ids=[],
+        current_token_ids=[],
+        delta_token_ids=[],
+        request=None,
+    )
+
+    # Should return the delta text as content
+    assert result is not None
+    assert hasattr(result, 'content')
+    assert result.content == " without any tool calls."
diff --git a/vllm/config.py b/vllm/config.py
index 1a3ff9d42ff61..90a0ad37e1cd1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1143,7 +1143,7 @@ class ModelConfig:
         if not hasattr(self.hf_text_config, "model_type"):
             return False
         elif self.hf_text_config.model_type in \
-            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
+            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
             return self.hf_text_config.kv_lora_rank is not None
         elif self.hf_text_config.model_type == 'eagle':
             # if the model is an EAGLE module, check for the
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 57e675515e12e..218a120a5bb01 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -8,6 +8,7 @@ from .granite_tool_parser import GraniteToolParser
 from .hermes_tool_parser import Hermes2ProToolParser
 from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
+from .kimi_k2_tool_parser import KimiK2ToolParser
 from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
 from .llama_tool_parser import Llama3JsonToolParser
 from .minimax_tool_parser import MinimaxToolParser
@@ -21,5 +22,6 @@ __all__ = [
     "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
     "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
-    "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
+    "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
+    "KimiK2ToolParser"
 ]
diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
new file mode 100644
index 0000000000000..b0df442dd8644
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
@@ -0,0 +1,377 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# code modified from deepseekv3_tool_parser.py
+
+from collections.abc import Sequence
+from typing import Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module(["kimi_k2"])
+class KimiK2ToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        # arguments streamed so far for each tool call, indexed by tool
+        self.streamed_args_for_tool: list[str] = []
+
+        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
+        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
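+
+        # Kimi-K2 emits all tool calls inside a single section wrapped by the
+        # two section tokens above; within it, each call is wrapped by the
+        # per-call tokens below and carries an ID of the form
+        # functions.{name}:{index}, e.g. functions.get_weather:0.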
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
+        )
+
+        self.stream_tool_call_name_regex = re.compile(
+            r"(?P<tool_call_id>[\w\.]+:\d+)\s*")
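+
+        # For example, tool_call_regex matches
+        #   <|tool_call_begin|> functions.get_weather:0
+        #   <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|>
+        # capturing tool_call_id "functions.get_weather:0" and
+        # function_arguments '{"city": "Beijing"}'.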
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor.")
+        self.tool_calls_start_token_id = self.vocab.get(
+            self.tool_calls_start_token)
+        self.tool_calls_end_token_id = self.vocab.get(
+            self.tool_calls_end_token)
+
+        self.tool_call_start_token_id = self.vocab.get(
+            self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+        if (self.tool_calls_start_token_id is None
+                or self.tool_calls_end_token_id is None):
+            raise RuntimeError(
+                "Kimi-K2 Tool parser could not locate tool call start/end "
+                "tokens in the tokenizer!")
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+
+        # sanity check; avoid unnecessary processing
+        if self.tool_calls_start_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        else:
+            try:
+                # findall returns a (tool_call_id, function_arguments) tuple
+                # for every well-formed tool call in the section
+                function_call_tuples = self.tool_call_regex.findall(
+                    model_output)
+
+                logger.debug("function_call_tuples: %s", function_call_tuples)
+
+                tool_calls = []
+                for match in function_call_tuples:
+                    function_id, function_args = match
+                    # function_id has the form functions.get_weather:0
+                    function_name = function_id.split('.')[1].split(':')[0]
+                    tool_calls.append(
+                        ToolCall(
+                            id=function_id,
+                            type='function',
+                            function=FunctionCall(name=function_name,
+                                                  arguments=function_args),
+                        ))
+
+                content = model_output[:model_output.
+                                       find(self.tool_calls_start_token)]
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None,
+                )
+
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+
+        logger.debug("delta_text: %s", delta_text)
+        logger.debug("delta_token_ids: %s", delta_token_ids)
+        # check to see if we should be streaming a tool call - is there a
+        # tool call section start token in the output so far?
+        if self.tool_calls_start_token_id not in current_token_ids:
+            logger.debug("No tool call tokens found!")
+            return DeltaMessage(content=delta_text)
+        delta_text = delta_text.replace(self.tool_calls_start_token,
+                                        "").replace(self.tool_calls_end_token,
+                                                    "")
+        try:
+
+            # figure out where we are in the parsing by counting tool call
+            # start & end tags
+            prev_tool_start_count = previous_token_ids.count(
+                self.tool_call_start_token_id)
+            prev_tool_end_count = previous_token_ids.count(
+                self.tool_call_end_token_id)
+            cur_tool_start_count = current_token_ids.count(
+                self.tool_call_start_token_id)
+            cur_tool_end_count = current_token_ids.count(
+                self.tool_call_end_token_id)
+            tool_call_portion = None
+            text_portion = None
+
+            # case: if we're generating text, OR rounding out a tool call
+            if (cur_tool_start_count == cur_tool_end_count
+                    and prev_tool_end_count == cur_tool_end_count
+                    and self.tool_call_end_token not in delta_text):
+                logger.debug("Generating text content! skipping tool parsing.")
+                return DeltaMessage(content=delta_text)
+
+            if self.tool_call_end_token in delta_text:
+                logger.debug("tool_call_end_token in delta_text")
+                full_text = current_text + delta_text
+                tool_call_portion = full_text.split(
+                    self.tool_call_start_token)[-1].split(
+                        self.tool_call_end_token)[0].rstrip()
+                delta_text = delta_text.split(
+                    self.tool_call_end_token)[0].rstrip()
+                text_portion = delta_text.split(
+                    self.tool_call_end_token)[-1].lstrip()
+
+            # case -- we're starting a new tool call
+            if (cur_tool_start_count > cur_tool_end_count
+                    and cur_tool_start_count > prev_tool_start_count):
+                if len(delta_token_ids) > 1:
+                    tool_call_portion = current_text.split(
+                        self.tool_call_start_token)[-1]
+                else:
+                    tool_call_portion = None
+                    delta = None
+
+                text_portion = None
+
+                # set cursors and state appropriately
+                self.current_tool_id += 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug("Starting on a new tool %s", self.current_tool_id)
+
+            # case -- we're updating an existing tool call
+            elif (cur_tool_start_count > cur_tool_end_count
+                  and cur_tool_start_count == prev_tool_start_count):
+
+                # get the portion of the text that's the tool call
+                tool_call_portion = current_text.split(
+                    self.tool_call_start_token)[-1]
+                text_portion = None
+
+            # case -- the current tool call is being closed.
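+            # Any argument tail that arrived in the final chunk has not been
+            # streamed yet, so flush it by cutting delta_text at the last
+            # '"}' that terminates the JSON arguments.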
+            elif (cur_tool_start_count == cur_tool_end_count
+                  and cur_tool_end_count >= prev_tool_end_count):
+                if self.prev_tool_call_arr is None or len(
+                        self.prev_tool_call_arr) == 0:
+                    logger.debug(
+                        "attempting to close tool call, but no tool call")
+                    return None
+                diff = self.prev_tool_call_arr[self.current_tool_id].get(
+                    "arguments")
+                if diff:
+                    diff = (diff.encode("utf-8").decode("unicode_escape")
+                            if isinstance(diff, str) else diff)
+                    if '"}' not in delta_text:
+                        return None
+                    end_loc = delta_text.rindex('"}')
+                    diff = delta_text[:end_loc] + '"}'
+                    logger.debug(
+                        "Finishing tool and found diff that had not "
+                        "been streamed yet: %s",
+                        diff,
+                    )
+                    self.streamed_args_for_tool[self.current_tool_id] += diff
+                    return DeltaMessage(tool_calls=[
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            function=DeltaFunctionCall(
+                                arguments=diff).model_dump(exclude_none=True),
+                        )
+                    ])
+
+            # case -- otherwise we're just generating text
+            else:
+                text = delta_text.replace(self.tool_call_start_token, "")
+                text = text.replace(self.tool_call_end_token, "")
+                delta = DeltaMessage(tool_calls=[], content=text)
+                return delta
+
+            current_tool_call = dict()
+            if tool_call_portion:
+                current_tool_call_matches = (
+                    self.stream_tool_call_portion_regex.match(
+                        tool_call_portion))
+                if current_tool_call_matches:
+                    tool_id, tool_args = current_tool_call_matches.groups()
+                    tool_name = tool_id.split('.')[1].split(':')[0]
+                    current_tool_call['id'] = tool_id
+                    current_tool_call["name"] = tool_name
+                    current_tool_call["arguments"] = tool_args
+                else:
+                    current_tool_call_name_matches = (
+                        self.stream_tool_call_name_regex.match(
+                            tool_call_portion))
+                    if current_tool_call_name_matches:
+                        tool_id_str, = current_tool_call_name_matches.groups()
+                        tool_name = tool_id_str.split('.')[1].split(':')[0]
+                        current_tool_call['id'] = tool_id_str
+                        current_tool_call["name"] = tool_name
+                        current_tool_call["arguments"] = ""
+                    else:
+                        logger.debug(
+                            "Not enough tokens to parse the tool call yet")
+                        return None
+
+            # case - we haven't sent the tool name yet. If it's available, send
+            # it. otherwise, wait until it's available.
+            if not self.current_tool_name_sent:
+                if current_tool_call is None:
+                    return None
+                function_name: Union[str, None] = current_tool_call.get("name")
+                tool_id = current_tool_call.get("id")
+                if function_name:
+                    self.current_tool_name_sent = True
+                    return DeltaMessage(tool_calls=[
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            type="function",
+                            id=tool_id,
+                            function=DeltaFunctionCall(
+                                name=function_name).model_dump(
+                                    exclude_none=True),
+                        )
+                    ])
+                else:
+                    return None
+
+            # case -- otherwise, send the tool call delta
+
+            # if the tool call portion is None, send the delta as text
+            if tool_call_portion is None:
+                # if there's text but not tool calls, send that -
+                # otherwise None to skip chunk
+                delta = (DeltaMessage(
+                    content=delta_text) if text_portion is not None else None)
+                return delta
+
+            # now, the nitty-gritty of tool calls
+            # now we have the portion to parse as tool call.
+
+            logger.debug("Trying to parse current tool call with ID %s",
+                         self.current_tool_id)
+
+            # if we're starting a new tool call, push an empty object in as
+            # a placeholder for the arguments
+            if len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+
+            # main logic for tool parsing here - compare the previously
+            # partially-parsed JSON arguments to the current ones
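+            # For example, if the previous pass saw arguments '{"city": "Bei'
+            # and this pass sees '{"city": "Beijing', only the suffix 'jing'
+            # is emitted as the argument delta.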
+            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+                "arguments")
+            cur_arguments = current_tool_call.get("arguments")
+
+            logger.debug("diffing old arguments: %s", prev_arguments)
+            logger.debug("against new ones: %s", cur_arguments)
+
+            # case -- no arguments have been created yet. skip sending a delta.
+            if not cur_arguments and not prev_arguments:
+                logger.debug("Skipping text %s - no arguments", delta_text)
+                delta = None
+
+            # case -- prev arguments are defined, but none are now.
+            # probably impossible, but not a fatal error - just keep going
+            elif not cur_arguments and prev_arguments:
+                logger.error("should be impossible to have arguments reset "
+                             "mid-call. skipping streaming anything.")
+                delta = None
+
+            # case -- we now have the first info about arguments available
+            elif cur_arguments and not prev_arguments:
+
+                delta = DeltaMessage(tool_calls=[
+                    DeltaToolCall(
+                        index=self.current_tool_id,
+                        function=DeltaFunctionCall(
+                            arguments=cur_arguments).model_dump(
+                                exclude_none=True),
+                    )
+                ])
+                self.streamed_args_for_tool[
+                    self.current_tool_id] = cur_arguments
+
+            # last case -- we have an update to existing arguments.
+            elif cur_arguments and prev_arguments:
+                if (isinstance(delta_text, str)
+                        and cur_arguments != prev_arguments
+                        and len(cur_arguments) > len(prev_arguments)
+                        and cur_arguments.startswith(prev_arguments)):
+                    delta_arguments = cur_arguments[len(prev_arguments):]
+                    logger.debug("got diff %s", delta_arguments)
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            function=DeltaFunctionCall(
+                                arguments=delta_arguments).model_dump(
+                                    exclude_none=True),
+                        )
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] = cur_arguments
+                else:
+                    delta = None
+
+            # handle saving the state for the current tool into
+            # the "prev" list for use in diffing for the next iteration
+            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+                self.prev_tool_call_arr[
+                    self.current_tool_id] = current_tool_call
+            else:
+                self.prev_tool_call_arr.append(current_tool_call)
+
+            return delta
+
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
+            return None  # do not stream a delta. skip this token ID.
\ No newline at end of file