diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index b6dfbf10b4568..c77fe44659790 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -376,6 +376,19 @@ Supported models:
 
 Flags: `--tool-call-parser olmo3`
 
+### GigaChat3 Models (`gigachat3`)
+
+Use the chat template from the Hugging Face model files.
+
+Supported models:
+
+* `ai-sage/GigaChat3-702B-A36B-preview`
+* `ai-sage/GigaChat3-702B-A36B-preview-bf16`
+* `ai-sage/GigaChat3-10B-A1.8B`
+* `ai-sage/GigaChat3-10B-A1.8B-bf16`
+
+Flags: `--tool-call-parser gigachat3`
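+
+The parser expects the model to emit the literal trigger `function call`, optionally followed by the chat template's `<|role_sep|>` token, and then a single JSON object with `name` and `arguments` fields. For example (taken from the parser tests):
+
+```text
+function call{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}
+```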
+
 ### Models with Pythonic Tool Calls (`pythonic`)
 
 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py
new file mode 100644
index 0000000000000..02c5189d0f6c1
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py
@@ -0,0 +1,176 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+
+from tests.entrypoints.openai.tool_parsers.utils import (
+    run_tool_extraction,
+    run_tool_extraction_streaming,
+)
+from vllm.entrypoints.openai.protocol import FunctionCall
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.tokenizers import TokenizerLike
+
+SIMPLE_ARGS_DICT = {
+    "action": "create",
+    "id": "preferences",
+}
+SIMPLE_FUNCTION_JSON = json.dumps(
+    {
+        "name": "manage_user_memory",
+        "arguments": SIMPLE_ARGS_DICT,
+    },
+    ensure_ascii=False,
+)
+SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
+SIMPLE_FUNCTION_CALL = FunctionCall(
+    name="manage_user_memory",
+    arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
+)
+
+
+PARAMETERLESS_FUNCTION_JSON = json.dumps(
+    {
+        "name": "manage_user_memory",
+        "arguments": {},
+    },
+    ensure_ascii=False,
+)
+PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
+PARAMETERLESS_FUNCTION_CALL = FunctionCall(
+    name="manage_user_memory",
+    arguments=json.dumps({}, ensure_ascii=False),
+)
+
+
+COMPLEX_ARGS_DICT = {
+    "action": "create",
+    "id": "preferences",
+    "content": {
+        "short_answers": True,
+        "hate_emojis": True,
+        "english_ui": False,
+        "russian_math_explanations": True,
+    },
+}
+COMPLEX_FUNCTION_JSON = json.dumps(
+    {
+        "name": "manage_user_memory",
+        "arguments": COMPLEX_ARGS_DICT,
+    },
+    ensure_ascii=False,
+)
+COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
+COMPLEX_FUNCTION_CALL = FunctionCall(
+    name="manage_user_memory",
+    arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
+)
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
+        default_tokenizer
+    )
+    model_output = "How can I help you today?"
+
+    content, tool_calls = run_tool_extraction(
+        tool_parser, model_output, streaming=streaming
+    )
+    assert content == model_output
+    assert len(tool_calls) == 0
+
+
+TEST_CASES = [
+    pytest.param(
+        True,
+        SIMPLE_FUNCTION_OUTPUT,
+        [SIMPLE_FUNCTION_CALL],
+        None,
+        id="simple_streaming",
+    ),
+    pytest.param(
+        False,
+        SIMPLE_FUNCTION_OUTPUT,
+        [SIMPLE_FUNCTION_CALL],
+        None,
+        id="simple_nonstreaming",
+    ),
+    pytest.param(
+        True,
+        PARAMETERLESS_FUNCTION_OUTPUT,
+        [PARAMETERLESS_FUNCTION_CALL],
+        None,
+        id="parameterless_streaming",
+    ),
+    pytest.param(
+        False,
+        PARAMETERLESS_FUNCTION_OUTPUT,
+        [PARAMETERLESS_FUNCTION_CALL],
+        None,
+        id="parameterless_nonstreaming",
+    ),
+    pytest.param(
+        True,
+        COMPLEX_FUNCTION_OUTPUT,
+        [COMPLEX_FUNCTION_CALL],
+        None,
+        id="complex_streaming",
+    ),
+    pytest.param(
+        False,
+        COMPLEX_FUNCTION_OUTPUT,
+        [COMPLEX_FUNCTION_CALL],
+        None,
+        id="complex_nonstreaming",
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
+)
+def test_tool_call(
+    streaming: bool,
+    model_output: str,
+    expected_tool_calls: list[FunctionCall],
+    expected_content: str | None,
+    default_tokenizer: TokenizerLike,
+):
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
+        default_tokenizer
+    )
+    content, tool_calls = run_tool_extraction(
+        tool_parser, model_output, streaming=streaming
+    )
+    assert content == expected_content
+    assert len(tool_calls) == len(expected_tool_calls)
+    for actual, expected in zip(tool_calls, expected_tool_calls):
+        assert actual.type == "function"
+        assert actual.function.name == expected.name
+        actual_args = json.loads(actual.function.arguments)
+        expected_args = json.loads(expected.arguments)
+        assert actual_args == expected_args
+
+
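+# The large-step case below feeds the call in coarse chunks ("function call", then two
+# halves of the JSON) to check that the parser re-diffs the accumulated arguments string
+# and only emits the newly seen suffix in each streamed DeltaToolCall.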
+def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
+        default_tokenizer
+    )
+    model_output_deltas = [
+        "function call",
+        COMPLEX_FUNCTION_JSON[:40],
+        COMPLEX_FUNCTION_JSON[40:],
+    ]
+    reconstructor = run_tool_extraction_streaming(
+        tool_parser,
+        model_output_deltas,
+        assert_one_tool_per_delta=False,
+    )
+    assert len(reconstructor.tool_calls) == 1
+    call = reconstructor.tool_calls[0]
+    assert call.type == "function"
+    assert call.function.name == "manage_user_memory"
+    args_dict = json.loads(call.function.arguments)
+    assert args_dict == COMPLEX_ARGS_DICT
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index ed43ea7eec82f..7be1263e802dc 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
         "xlam_tool_parser",
         "xLAMToolParser",
     ),
+    "gigachat3": (
+        "gigachat3_tool_parser",
+        "GigaChat3ToolParser",
+    ),
 }
diff --git a/vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
new file mode 100644
index 0000000000000..dd27ffa83cfc4
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
@@ -0,0 +1,190 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Sequence
+
+import regex as re
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+
+logger = init_logger(__name__)
+
+REGEX_FUNCTION_CALL = re.compile(
+    r"function call(?:<\|role_sep\|>\n)?(\{.*)",
+    re.DOTALL,
+)
+
+NAME_REGEX = re.compile(
+    r'"name"\s*:\s*"([^"]*)"',
+    re.DOTALL,
+)
+
+ARGS_REGEX = re.compile(
+    r'"arguments"\s*:\s*(.*)',
+    re.DOTALL,
+)
+
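+# Example output these patterns target (taken from the tests):
+#   function call{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}
+# REGEX_FUNCTION_CALL captures the JSON object that follows the literal `function call`
+# trigger (optionally separated by `<|role_sep|>`); NAME_REGEX and ARGS_REGEX then pull
+# the function name and the (possibly still incomplete) arguments text out of that
+# capture while the response is streamed.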
+
+
+class GigaChat3ToolParser(ToolParser):
+    """Tool call parser for GigaChat3 models that emit `function call{...}` payloads."""
+
+    def __init__(self, tokenizer: TokenizerLike):
+        super().__init__(tokenizer)
+        self.tool_started: bool = False
+        self.tool_name_sent: bool = False
+        self.tool_id: str | None = None
+        self.prev_tool_call_arr: list[dict] = []
+        self.content_buffer: str = ""
+        self.trigger_start = "function call{"
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        match = REGEX_FUNCTION_CALL.search(model_output)
+        if not match:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+        json_candidate = match.group(1).strip()
+        try:
+            data = json.loads(json_candidate)
+        except json.JSONDecodeError:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+        if not (isinstance(data, dict) and "name" in data and "arguments" in data):
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+        name = data["name"]
+        args = data["arguments"]
+        if not isinstance(args, str):
+            args = json.dumps(args, ensure_ascii=False)
+
+        tool_calls = [
+            ToolCall(
+                type="function",
+                function=FunctionCall(
+                    name=name,
+                    arguments=args,
+                ),
+            )
+        ]
+        prefix = model_output[: match.start()]
+        content = prefix.rstrip() if prefix and prefix.strip() else None
+
+        return ExtractedToolCallInformation(
+            tools_called=True,
+            tool_calls=tool_calls,
+            content=content,
+        )
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> DeltaMessage | None:
+        func_name = None
+        cur_args = None
+        if not self.tool_started:
+            match = REGEX_FUNCTION_CALL.search(current_text)
+            if match:
+                self.tool_started = True
+                self.content_buffer = ""
+            else:
+                self.content_buffer += delta_text
+                clean_buffer = self.content_buffer.lstrip()
+                is_prefix = self.trigger_start.startswith(clean_buffer)
+                starts_with_trigger = clean_buffer.startswith(self.trigger_start)
+                if is_prefix or starts_with_trigger:
+                    return None
+                else:
+                    flush_text = self.content_buffer
+                    self.content_buffer = ""
+                    return DeltaMessage(content=flush_text)
+
+        match = REGEX_FUNCTION_CALL.search(current_text)
+        if not match:
+            return None
+        json_tail = match.group(1).strip()
+        name_match = NAME_REGEX.search(json_tail)
+        if name_match:
+            func_name = name_match.group(1)
+        args_match = ARGS_REGEX.search(json_tail)
+        if args_match:
+            cur_args = args_match.group(1).strip()
+            if cur_args.endswith("}"):
+                # The trailing '}' may close the outer JSON object rather than the
+                # arguments themselves; drop it if what remains still parses.
+                try:
+                    candidate = cur_args[:-1].strip()
+                    json.loads(candidate)
+                    cur_args = candidate
+                except json.JSONDecodeError:
+                    pass
+        if not self.prev_tool_call_arr:
+            self.prev_tool_call_arr.append({})
+        if not self.tool_name_sent:
+            if not func_name:
+                return None
+            self.tool_name_sent = True
+            self.tool_id = make_tool_call_id()
+            self.prev_tool_call_arr[0]["name"] = func_name
+            return DeltaMessage(
+                tool_calls=[
+                    DeltaToolCall(
+                        index=0,
+                        id=self.tool_id,
+                        type="function",
+                        function=DeltaFunctionCall(
+                            name=func_name,
+                        ).model_dump(exclude_none=True),
+                    )
+                ],
+                content=None,
+            )
+        if cur_args is None:
+            return None
+        prev_args = self.prev_tool_call_arr[0].get("arguments", "")
+        if not prev_args:
+            delta_args = cur_args
+        elif cur_args.startswith(prev_args):
+            delta_args = cur_args[len(prev_args) :]
+        else:
+            return None
+        if not delta_args:
+            return None
+        self.prev_tool_call_arr[0]["arguments"] = cur_args
+        return DeltaMessage(
+            tool_calls=[
+                DeltaToolCall(
+                    index=0,
+                    function=DeltaFunctionCall(
+                        arguments=delta_args,
+                    ).model_dump(exclude_none=True),
+                )
+            ],
+            content=None,
+        )
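For a quick end-to-end check of the parser, the following is a minimal client-side sketch. The tool schema and prompt are illustrative; only the `manage_user_memory` name and the model IDs come from the tests and docs above, and `--enable-auto-tool-choice` is vLLM's standard flag for automatic tool choice documented elsewhere in `tool_calling.md`.

```python
# Illustrative only. First start the server with the new parser enabled, e.g.:
#   vllm serve ai-sage/GigaChat3-10B-A1.8B --enable-auto-tool-choice --tool-call-parser gigachat3
# then issue a standard OpenAI-compatible tool-call request.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Hypothetical tool schema; only the function name is taken from the tests above.
tools = [
    {
        "type": "function",
        "function": {
            "name": "manage_user_memory",
            "description": "Create or update a stored user preference.",
            "parameters": {
                "type": "object",
                "properties": {
                    "action": {"type": "string"},
                    "id": {"type": "string"},
                },
                "required": ["action", "id"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="ai-sage/GigaChat3-10B-A1.8B",
    messages=[{"role": "user", "content": "Remember that I prefer short answers."}],
    tools=tools,
    tool_choice="auto",
)
# If the model answers with `function call{...}`, the gigachat3 parser surfaces it here
# as a structured tool call instead of plain content.
print(response.choices[0].message.tool_calls)
```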