mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 12:05:48 +08:00
[Bugfix] Improve JSON extraction in LlamaToolParser (#19024)
Signed-off-by: keru <keyang.ru@oracle.com> Co-authored-by: keru <keyang.ru@oracle.com>
This commit is contained in:
parent
656c24f1b5
commit
9ace2eaf35
@ -0,0 +1,132 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import (
|
||||||
|
Llama3JsonToolParser)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def parser():
    """Build a Llama3JsonToolParser around the small "gpt2" tokenizer."""
    return Llama3JsonToolParser(AutoTokenizer.from_pretrained("gpt2"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_simple(parser):
    """A single tool call embedded in prose is extracted; content is None."""
    text = ('Here is the result: '
            '{"name": "getOpenIncidentsTool", "parameters": {}}'
            ' Would you like to know more?')
    extracted = parser.extract_tool_calls(text, None)

    assert isinstance(extracted, ExtractedToolCallInformation)
    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    call = extracted.tool_calls[0]
    assert call.type == "function"
    assert call.function.name == "getOpenIncidentsTool"
    assert call.function.arguments == "{}"
    assert extracted.content is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_arguments(parser):
    """Tool call parameters show up in the extracted arguments string."""
    text = ('{"name": "searchTool", '
            '"parameters": {"query": "test query", "limit": 10}}')
    extracted = parser.extract_tool_calls(text, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    function = extracted.tool_calls[0].function
    assert function.name == "searchTool"
    for fragment in ('"query": "test query"', '"limit": 10'):
        assert fragment in function.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_no_json(parser):
    """Text without any JSON object comes back untouched as content."""
    plain_text = "This is just some text without any tool calls"
    extracted = parser.extract_tool_calls(plain_text, None)

    # No tool call was found, so the whole output is the message content.
    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    assert extracted.content == plain_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_invalid_json(parser):
    """Malformed JSON is not reported as a tool call; content is kept."""
    broken = '{"name": "invalidTool", "parameters": {invalid json}'
    extracted = parser.extract_tool_calls(broken, None)

    # The parser must fall back to returning the raw output as content.
    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    assert extracted.content == broken
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_arguments_key(parser):
    """The "arguments" key is accepted in place of "parameters"."""
    text = '{"name": "searchTool", "arguments": {"query": "test"}}'
    extracted = parser.extract_tool_calls(text, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    function = extracted.tool_calls[0].function
    assert function.name == "searchTool"
    assert '"query": "test"' in function.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_multiple_json(parser):
    """Three semicolon-separated JSON objects yield three ordered calls."""
    text = '; '.join((
        '{"name": "searchTool", "parameters": {"query": "test1"}}',
        '{"name": "getOpenIncidentsTool", "parameters": {}}',
        '{"name": "searchTool", "parameters": {"query": "test2"}}',
    ))
    extracted = parser.extract_tool_calls(text, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 3

    first, second, third = (call.function for call in extracted.tool_calls)

    # Calls must come back in the order they appeared in the output.
    assert first.name == "searchTool"
    assert '"query": "test1"' in first.arguments

    assert second.name == "getOpenIncidentsTool"
    assert second.arguments == "{}"

    assert third.name == "searchTool"
    assert '"query": "test2"' in third.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_multiple_json_with_whitespace(parser):
    """Semicolon separators padded with spaces still split the tool calls.

    Checks the arguments of every call as well as its name, so a regression
    in separator handling that corrupts a payload is caught here too.
    """
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test1"}} ; '
        '{"name": "getOpenIncidentsTool", "parameters": {}} ; '
        '{"name": "searchTool", "parameters": {"query": "test2"}}')
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 3

    # Check first tool call
    assert result.tool_calls[0].function.name == "searchTool"
    assert '"query": "test1"' in result.tool_calls[0].function.arguments

    # Check second tool call
    assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
    assert result.tool_calls[1].function.arguments == "{}"

    # Check third tool call
    assert result.tool_calls[2].function.name == "searchTool"
    assert '"query": "test2"' in result.tool_calls[2].function.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
    """Several tool calls wrapped in prose are extracted; prose is dropped.

    Besides the names, this verifies each call's arguments and that the
    surrounding text is not returned as content.
    """
    model_output = (
        'Here are the results: '
        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
        '{"name": "searchTool", "parameters": {"query": "test2"}} '
        'Would you like to know more?')
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 3

    # Check first tool call
    assert result.tool_calls[0].function.name == "searchTool"
    assert '"query": "test1"' in result.tool_calls[0].function.arguments

    # Check second tool call
    assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
    assert result.tool_calls[1].function.arguments == "{}"

    # Check third tool call
    assert result.tool_calls[2].function.name == "searchTool"
    assert '"query": "test2"' in result.tool_calls[2].function.arguments

    # Text around the JSON objects is ignored once tool calls are found.
    assert result.content is None
|
||||||
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from json import JSONDecoder
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
@ -31,11 +30,11 @@ logger = init_logger(__name__)
|
|||||||
@ToolParserManager.register_module("llama4_json")
|
@ToolParserManager.register_module("llama4_json")
|
||||||
class Llama3JsonToolParser(ToolParser):
|
class Llama3JsonToolParser(ToolParser):
|
||||||
"""
|
"""
|
||||||
Tool call parser for Llama 3.1 models intended for use with the
|
Tool call parser for Llama 3.x and 4 models intended for use with the
|
||||||
examples/tool_chat_template_llama.jinja template.
|
examples/tool_chat_template_llama.jinja template.
|
||||||
|
|
||||||
Used when --enable-auto-tool-choice --tool-call-parser llama3_json
|
Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
|
||||||
are all set
|
llama4_json are set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||||
@ -51,54 +50,57 @@ class Llama3JsonToolParser(ToolParser):
|
|||||||
self.bot_token = "<|python_tag|>"
|
self.bot_token = "<|python_tag|>"
|
||||||
self.bot_token_id = tokenizer.encode(self.bot_token,
|
self.bot_token_id = tokenizer.encode(self.bot_token,
|
||||||
add_special_tokens=False)[0]
|
add_special_tokens=False)[0]
|
||||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
# Updated regex to match multiple JSONs separated by semicolons
|
||||||
|
# NOTE: the pattern only matches objects nested one level deep (e.g. a
# flat "parameters" dict); deeper nesting is not matched as a whole.
|
||||||
|
self.tool_call_regex = re.compile(
|
||||||
|
r'{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*',
|
||||||
|
re.DOTALL)
|
||||||
|
|
||||||
def extract_tool_calls(
|
def extract_tool_calls(
|
||||||
self, model_output: str,
|
self, model_output: str,
|
||||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||||
"""
|
"""
|
||||||
Extract the tool calls from a complete model response.
|
Extract the tool calls from a complete model response.
|
||||||
|
Only extracts JSON content and ignores any surrounding plain text.
|
||||||
|
Supports both single JSON and multiple JSONs separated by semicolons.
|
||||||
"""
|
"""
|
||||||
# case -- if a tool call token is not present, return a text response
|
# Quick check before running regex
|
||||||
if not (model_output.startswith(self.bot_token)
|
if not (self.bot_token in model_output or '{' in model_output):
|
||||||
or model_output.startswith('{')):
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
# Find JSON object(s) in the text using regex
|
||||||
|
match = self.tool_call_regex.search(model_output)
|
||||||
|
if not match:
|
||||||
return ExtractedToolCallInformation(tools_called=False,
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
content=model_output)
|
content=model_output)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# load the JSON, and then use it to build the Function and
|
json_str = match.group(0)
|
||||||
# Tool Call
|
# Split by semicolon and strip whitespace (note: a ';' inside a JSON
# string value is split as well, which makes that object fail to parse)
|
||||||
dec = JSONDecoder()
|
json_objects = [obj.strip() for obj in json_str.split(';')]
|
||||||
function_call_arr = []
|
|
||||||
|
|
||||||
# depending on the prompt format the Llama model may or may not
|
tool_calls: list[ToolCall] = []
|
||||||
# prefix the output with the <|python_tag|> token
|
for json_obj in json_objects:
|
||||||
start_idx = len(self.bot_token) if model_output.startswith(
|
if not json_obj: # Skip empty strings
|
||||||
self.bot_token) else 0
|
continue
|
||||||
while start_idx < len(model_output):
|
obj = json.loads(json_obj)
|
||||||
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
|
tool_calls.append(
|
||||||
start_idx += end_idx + len('; ')
|
|
||||||
function_call_arr.append(obj)
|
|
||||||
|
|
||||||
tool_calls: list[ToolCall] = [
|
|
||||||
ToolCall(
|
ToolCall(
|
||||||
type="function",
|
type="function",
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=raw_function_call["name"],
|
name=obj["name"],
|
||||||
# function call args are JSON but as a string
|
# function call args are JSON but as a string
|
||||||
arguments=json.dumps(raw_function_call["arguments"] \
|
arguments=json.dumps(
|
||||||
if "arguments" in raw_function_call \
|
obj["arguments"]
|
||||||
else raw_function_call["parameters"],
|
if "arguments" in obj else obj["parameters"],
|
||||||
ensure_ascii=False)))
|
ensure_ascii=False))))
|
||||||
for raw_function_call in function_call_arr
|
|
||||||
]
|
|
||||||
|
|
||||||
# get any content before the tool call
|
return ExtractedToolCallInformation(tools_called=True,
|
||||||
ret = ExtractedToolCallInformation(tools_called=True,
|
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
content=None)
|
content=None)
|
||||||
return ret
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error in extracting tool call from response.")
|
logger.exception("Error in extracting tool call from response.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user