Gigachat 3 tool parser and tests (#29905)

Signed-off-by: Viacheslav Barinov <viacheslav.teh@gmail.com>
This commit is contained in:
Viacheslav 2025-12-06 15:04:14 +03:00 committed by GitHub
parent 17a9abec2b
commit 21bb323542
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 383 additions and 0 deletions

View File

@ -376,6 +376,19 @@ Supported models:
Flags: `--tool-call-parser olmo3`
### Gigachat 3 Models (`gigachat3`)
Use chat template from the Hugging Face model files.
Supported models:
* `ai-sage/GigaChat3-702B-A36B-preview`
* `ai-sage/GigaChat3-702B-A36B-preview-bf16`
* `ai-sage/GigaChat3-10B-A1.8B`
* `ai-sage/GigaChat3-10B-A1.8B-bf16`
Flags: `--tool-call-parser gigachat3`
### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.

View File

@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
SIMPLE_ARGS_DICT = {
"action": "create",
"id": "preferences",
}
SIMPLE_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": SIMPLE_ARGS_DICT,
},
ensure_ascii=False,
)
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
SIMPLE_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
)
PARAMETERLESS_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": {},
},
ensure_ascii=False,
)
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps({}, ensure_ascii=False),
)
COMPLEX_ARGS_DICT = {
"action": "create",
"id": "preferences",
"content": {
"short_answers": True,
"hate_emojis": True,
"english_ui": False,
"russian_math_explanations": True,
},
}
COMPLEX_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": COMPLEX_ARGS_DICT,
},
ensure_ascii=False,
)
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
COMPLEX_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(
True,
SIMPLE_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_streaming",
),
pytest.param(
False,
SIMPLE_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_nonstreaming",
),
pytest.param(
True,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_streaming",
),
pytest.param(
False,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_nonstreaming",
),
pytest.param(
True,
COMPLEX_FUNCTION_OUTPUT,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_streaming",
),
pytest.param(
False,
COMPLEX_FUNCTION_OUTPUT,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_nonstreaming",
),
]
@pytest.mark.parametrize(
"streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
expected_content: str | None,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == expected_content
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function.name == expected.name
actual_args = json.loads(actual.function.arguments)
expected_args = json.loads(expected.arguments)
assert actual_args == expected_args
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
model_output_deltas = [
"function call",
COMPLEX_FUNCTION_JSON[:40],
COMPLEX_FUNCTION_JSON[40:],
]
reconstructor = run_tool_extraction_streaming(
tool_parser,
model_output_deltas,
assert_one_tool_per_delta=False,
)
assert len(reconstructor.tool_calls) == 1
call = reconstructor.tool_calls[0]
assert call.type == "function"
assert call.function.name == "manage_user_memory"
args_dict = json.loads(call.function.arguments)
assert args_dict == COMPLEX_ARGS_DICT

View File

@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
REGEX_FUNCTION_CALL = re.compile(
r"function call(?:<\|role_sep\|>\n)?(\{.*)",
re.DOTALL,
)
NAME_REGEX = re.compile(
r'"name"\s*:\s*"([^"]*)"',
re.DOTALL,
)
ARGS_REGEX = re.compile(
r'"arguments"\s*:\s*(.*)',
re.DOTALL,
)
class GigaChat3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_started: bool = False
self.tool_name_sent: bool = False
self.tool_id: str | None = None
self.prev_tool_call_arr: list[dict] = []
self.content_buffer: str = ""
self.trigger_start = "function call{"
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
match = REGEX_FUNCTION_CALL.search(model_output)
if not match:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
json_candidate = match.group(1).strip()
try:
data = json.loads(json_candidate)
except json.JSONDecodeError:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
if not (isinstance(data, dict) and "name" in data and "arguments" in data):
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
name = data["name"]
args = data["arguments"]
if not isinstance(args, str):
args = json.dumps(args, ensure_ascii=False)
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=name,
arguments=args,
),
)
]
prefix = model_output[: match.start()]
content = prefix.rstrip() if prefix and prefix.strip() else None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
func_name = None
cur_args = None
if not self.tool_started:
match = REGEX_FUNCTION_CALL.search(current_text)
if match:
self.tool_started = True
self.content_buffer = ""
else:
self.content_buffer += delta_text
clean_buffer = self.content_buffer.lstrip()
is_prefix = self.trigger_start.startswith(clean_buffer)
starts_with_trigger = clean_buffer.startswith(self.trigger_start)
if is_prefix or starts_with_trigger:
return None
else:
flush_text = self.content_buffer
self.content_buffer = ""
return DeltaMessage(content=flush_text)
match = REGEX_FUNCTION_CALL.search(current_text)
if not match:
return None
json_tail = match.group(1).strip()
name_match = NAME_REGEX.search(json_tail)
if name_match:
func_name = name_match.group(1)
args_match = ARGS_REGEX.search(json_tail)
if args_match:
cur_args = args_match.group(1).strip()
if cur_args.endswith("}"): # last '}' end of json
try:
candidate = cur_args[:-1].strip()
json.loads(candidate)
cur_args = candidate
except json.JSONDecodeError:
pass
if not self.prev_tool_call_arr:
self.prev_tool_call_arr.append({})
if not self.tool_name_sent:
if not func_name:
return None
self.tool_name_sent = True
self.tool_id = make_tool_call_id()
self.prev_tool_call_arr[0]["name"] = func_name
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id=self.tool_id,
type="function",
function=DeltaFunctionCall(
name=func_name,
).model_dump(exclude_none=True),
)
],
content=None,
)
if cur_args is None:
return None
prev_args = self.prev_tool_call_arr[0].get("arguments", "")
if not prev_args:
delta_args = cur_args
elif cur_args.startswith(prev_args):
delta_args = cur_args[len(prev_args) :]
else:
return None
if not delta_args:
return None
self.prev_tool_call_arr[0]["arguments"] = cur_args
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
function=DeltaFunctionCall(
arguments=delta_args,
).model_dump(exclude_none=True),
)
],
content=None,
)