mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 15:54:31 +08:00
Gigachat 3 tool parser and tests (#29905)
Signed-off-by: Viacheslav Barinov <viacheslav.teh@gmail.com>
This commit is contained in:
parent
17a9abec2b
commit
21bb323542
@ -376,6 +376,19 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser olmo3`
|
||||
|
||||
### Gigachat 3 Models (`gigachat3`)
|
||||
|
||||
Use chat template from the Hugging Face model files.
|
||||
|
||||
Supported models:
|
||||
|
||||
* `ai-sage/GigaChat3-702B-A36B-preview`
|
||||
* `ai-sage/GigaChat3-702B-A36B-preview-bf16`
|
||||
* `ai-sage/GigaChat3-10B-A1.8B`
|
||||
* `ai-sage/GigaChat3-10B-A1.8B-bf16`
|
||||
|
||||
Flags: `--tool-call-parser gigachat3`
|
||||
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
||||
@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
SIMPLE_ARGS_DICT = {
|
||||
"action": "create",
|
||||
"id": "preferences",
|
||||
}
|
||||
SIMPLE_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": SIMPLE_ARGS_DICT,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
PARAMETERLESS_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": {},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps({}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
COMPLEX_ARGS_DICT = {
|
||||
"action": "create",
|
||||
"id": "preferences",
|
||||
"content": {
|
||||
"short_answers": True,
|
||||
"hate_emojis": True,
|
||||
"english_ui": False,
|
||||
"russian_math_explanations": True,
|
||||
},
|
||||
}
|
||||
COMPLEX_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": COMPLEX_ARGS_DICT,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
|
||||
COMPLEX_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLEX_FUNCTION_OUTPUT,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLEX_FUNCTION_OUTPUT,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
|
||||
)
|
||||
def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
expected_content: str | None,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
assert content == expected_content
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function.name == expected.name
|
||||
actual_args = json.loads(actual.function.arguments)
|
||||
expected_args = json.loads(expected.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"function call",
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:],
|
||||
]
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser,
|
||||
model_output_deltas,
|
||||
assert_one_tool_per_delta=False,
|
||||
)
|
||||
assert len(reconstructor.tool_calls) == 1
|
||||
call = reconstructor.tool_calls[0]
|
||||
assert call.type == "function"
|
||||
assert call.function.name == "manage_user_memory"
|
||||
args_dict = json.loads(call.function.arguments)
|
||||
assert args_dict == COMPLEX_ARGS_DICT
|
||||
@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"xlam_tool_parser",
|
||||
"xLAMToolParser",
|
||||
),
|
||||
"gigachat3": (
|
||||
"gigachat3_tool_parser",
|
||||
"GigaChat3ToolParser",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
190
vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
Normal file
190
vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
REGEX_FUNCTION_CALL = re.compile(
|
||||
r"function call(?:<\|role_sep\|>\n)?(\{.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
NAME_REGEX = re.compile(
|
||||
r'"name"\s*:\s*"([^"]*)"',
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
ARGS_REGEX = re.compile(
|
||||
r'"arguments"\s*:\s*(.*)',
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
class GigaChat3ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.tool_started: bool = False
|
||||
self.tool_name_sent: bool = False
|
||||
self.tool_id: str | None = None
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.content_buffer: str = ""
|
||||
self.trigger_start = "function call{"
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
match = REGEX_FUNCTION_CALL.search(model_output)
|
||||
if not match:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
json_candidate = match.group(1).strip()
|
||||
try:
|
||||
data = json.loads(json_candidate)
|
||||
except json.JSONDecodeError:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
if not (isinstance(data, dict) and "name" in data and "arguments" in data):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
name = data["name"]
|
||||
args = data["arguments"]
|
||||
if not isinstance(args, str):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
arguments=args,
|
||||
),
|
||||
)
|
||||
]
|
||||
prefix = model_output[: match.start()]
|
||||
content = prefix.rstrip() if prefix and prefix.strip() else None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
func_name = None
|
||||
cur_args = None
|
||||
if not self.tool_started:
|
||||
match = REGEX_FUNCTION_CALL.search(current_text)
|
||||
if match:
|
||||
self.tool_started = True
|
||||
self.content_buffer = ""
|
||||
else:
|
||||
self.content_buffer += delta_text
|
||||
clean_buffer = self.content_buffer.lstrip()
|
||||
is_prefix = self.trigger_start.startswith(clean_buffer)
|
||||
starts_with_trigger = clean_buffer.startswith(self.trigger_start)
|
||||
if is_prefix or starts_with_trigger:
|
||||
return None
|
||||
else:
|
||||
flush_text = self.content_buffer
|
||||
self.content_buffer = ""
|
||||
return DeltaMessage(content=flush_text)
|
||||
|
||||
match = REGEX_FUNCTION_CALL.search(current_text)
|
||||
if not match:
|
||||
return None
|
||||
json_tail = match.group(1).strip()
|
||||
name_match = NAME_REGEX.search(json_tail)
|
||||
if name_match:
|
||||
func_name = name_match.group(1)
|
||||
args_match = ARGS_REGEX.search(json_tail)
|
||||
if args_match:
|
||||
cur_args = args_match.group(1).strip()
|
||||
if cur_args.endswith("}"): # last '}' end of json
|
||||
try:
|
||||
candidate = cur_args[:-1].strip()
|
||||
json.loads(candidate)
|
||||
cur_args = candidate
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr.append({})
|
||||
if not self.tool_name_sent:
|
||||
if not func_name:
|
||||
return None
|
||||
self.tool_name_sent = True
|
||||
self.tool_id = make_tool_call_id()
|
||||
self.prev_tool_call_arr[0]["name"] = func_name
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id=self.tool_id,
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=func_name,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
if cur_args is None:
|
||||
return None
|
||||
prev_args = self.prev_tool_call_arr[0].get("arguments", "")
|
||||
if not prev_args:
|
||||
delta_args = cur_args
|
||||
elif cur_args.startswith(prev_args):
|
||||
delta_args = cur_args[len(prev_args) :]
|
||||
else:
|
||||
return None
|
||||
if not delta_args:
|
||||
return None
|
||||
self.prev_tool_call_arr[0]["arguments"] = cur_args
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_args,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user