diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
new file mode 100644
index 0000000000000..09726c7e3e5b5
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
+from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import (
+    Llama3JsonToolParser)
+
+
+@pytest.fixture
+def parser():
+    # Use a small tokenizer for testing
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    return Llama3JsonToolParser(tokenizer)
+
+
+def test_extract_tool_calls_simple(parser):
+    # Test with a simple tool call
+    model_output = ('Here is the result: {"name": "getOpenIncidentsTool", '
+                    '"parameters": {}} Would you like to know more?')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert isinstance(result, ExtractedToolCallInformation)
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].type == "function"
+    assert result.tool_calls[0].function.name == "getOpenIncidentsTool"
+    assert result.tool_calls[0].function.arguments == "{}"
+    assert result.content is None
+
+
+def test_extract_tool_calls_with_arguments(parser):
+    # Test with a tool call that has arguments
+    model_output = (
+        '{"name": "searchTool", "parameters": {"query": "test query", '
+        '"limit": 10}}')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert '"query": "test query"' in result.tool_calls[0].function.arguments
+    assert '"limit": 10' in result.tool_calls[0].function.arguments
+
+
+def test_extract_tool_calls_no_json(parser):
+    # Test with text that doesn't contain a JSON object
+    model_output = "This is just some text without any tool calls"
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is False
+    assert len(result.tool_calls) == 0
+    assert result.content == model_output
+
+
+def test_extract_tool_calls_invalid_json(parser):
+    # Test with invalid JSON
+    model_output = '{"name": "invalidTool", "parameters": {invalid json}'
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is False
+    assert len(result.tool_calls) == 0
+    assert result.content == model_output
+
+
+def test_extract_tool_calls_with_arguments_key(parser):
+    # Test with a tool call that uses "arguments" instead of "parameters"
+    model_output = '{"name": "searchTool", "arguments": {"query": "test"}}'
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert '"query": "test"' in result.tool_calls[0].function.arguments
+
+
+def test_extract_tool_calls_multiple_json(parser):
+    # Test with multiple JSONs separated by semicolons
+    model_output = (
+        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
+        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
+        '{"name": "searchTool", "parameters": {"query": "test2"}}')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 3
+
+    # Check first tool call
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert '"query": "test1"' in result.tool_calls[0].function.arguments
+
+    # Check second tool call
+    assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
+    assert result.tool_calls[1].function.arguments == "{}"
+
+    # Check third tool call
+    assert result.tool_calls[2].function.name == "searchTool"
+    assert '"query": "test2"' in result.tool_calls[2].function.arguments
+
+
+def test_extract_tool_calls_multiple_json_with_whitespace(parser):
+    # Test with multiple JSONs separated by semicolons and extra whitespace
+    model_output = (
+        '{"name": "searchTool", "parameters": {"query": "test1"}} ; '
+        '{"name": "getOpenIncidentsTool", "parameters": {}} ; '
+        '{"name": "searchTool", "parameters": {"query": "test2"}}')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 3
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
+    assert result.tool_calls[2].function.name == "searchTool"
+
+
+def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
+    # Test with multiple JSONs and surrounding text
+    model_output = (
+        'Here are the results: '
+        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
+        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
+        '{"name": "searchTool", "parameters": {"query": "test2"}} '
+        'Would you like to know more?')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 3
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
+    assert result.tool_calls[2].function.name == "searchTool"
+
+
+def test_extract_tool_calls_semicolon_in_string(parser):
+    # Test that a semicolon inside a string value does not split the JSON
+    model_output = (
+        '{"name": "searchTool", "parameters": {"query": "a; b"}}')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert '"query": "a; b"' in result.tool_calls[0].function.arguments
+
+
+def test_extract_tool_calls_deeply_nested_arguments(parser):
+    # Test with arguments nested more than one level deep
+    model_output = (
+        '{"name": "searchTool", "parameters": '
+        '{"filter": {"range": {"gte": 1}}}}')
+    result = parser.extract_tool_calls(model_output, None)
+
+    assert result.tools_called is True
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].function.name == "searchTool"
+    assert '"range": {"gte": 1}' in result.tool_calls[0].function.arguments
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 5698bc70af23b..194a144ad576e 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -3,7 +3,6 @@
 
 import json
 from collections.abc import Sequence
-from json import JSONDecoder
 from typing import Union
 
 import partial_json_parser
@@ -31,11 +30,11 @@ logger = init_logger(__name__)
 @ToolParserManager.register_module("llama4_json")
 class Llama3JsonToolParser(ToolParser):
     """
-    Tool call parser for Llama 3.1 models intended for use with the
+    Tool call parser for Llama 3.x and 4 models intended for use with the
     examples/tool_chat_template_llama.jinja template.
 
-    Used when --enable-auto-tool-choice --tool-call-parser llama3_json
-    are all set
+    Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
+    llama4_json are set.
     """
 
     def __init__(self, tokenizer: PreTrainedTokenizerBase):
@@ -58,47 +57,59 @@ class Llama3JsonToolParser(ToolParser):
             request: ChatCompletionRequest) -> ExtractedToolCallInformation:
         """
         Extract the tool calls from a complete model response.
+        Tool calls are JSON objects carrying a "name" plus "arguments" (or
+        "parameters"); they may appear anywhere in the text, and any text
+        around or between them (e.g. "; " separators) is ignored.
         """
-        # case -- if a tool call token is not present, return a text response
-        if not (model_output.startswith(self.bot_token)
-                or model_output.startswith('{')):
+        # Quick check: without an opening brace there is no JSON to decode
+        if '{' not in model_output:
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
                                                 content=model_output)
 
         try:
-            # load the JSON, and then use it to build the Function and
-            # Tool Call
-            dec = JSONDecoder()
-            function_call_arr = []
+            # Decode with a real JSON decoder instead of regex-matching and
+            # splitting on ';', so arbitrarily nested arguments as well as
+            # braces or semicolons inside string values are handled.
+            decoder = json.JSONDecoder()
+            tool_calls: list[ToolCall] = []
 
-            # depending on the prompt format the Llama model may or may not
-            # prefix the output with the <|python_tag|> token
-            start_idx = len(self.bot_token) if model_output.startswith(
-                self.bot_token) else 0
-            while start_idx < len(model_output):
-                (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
-                start_idx += end_idx + len('; ')
-                function_call_arr.append(obj)
+            start_idx = model_output.find('{')
+            while start_idx != -1:
+                try:
+                    obj, length = decoder.raw_decode(model_output[start_idx:])
+                except json.JSONDecodeError:
+                    # no valid JSON starts here; try the next opening brace
+                    start_idx = model_output.find('{', start_idx + 1)
+                    continue
+
+                # resume scanning after this object, not inside it
+                start_idx = model_output.find('{', start_idx + length)
+                if "name" not in obj:
+                    continue
+                if "arguments" in obj:
+                    arguments = obj["arguments"]
+                elif "parameters" in obj:
+                    arguments = obj["parameters"]
+                else:
+                    continue
 
-            tool_calls: list[ToolCall] = [
-                ToolCall(
-                    type="function",
-                    function=FunctionCall(
-                        name=raw_function_call["name"],
-                        # function call args are JSON but as a string
-                        arguments=json.dumps(raw_function_call["arguments"] \
-                                if "arguments" in raw_function_call \
-                                else raw_function_call["parameters"],
-                            ensure_ascii=False)))
-                for raw_function_call in function_call_arr
-            ]
-
-            # get any content before the tool call
-            ret = ExtractedToolCallInformation(tools_called=True,
-                                               tool_calls=tool_calls,
-                                               content=None)
-            return ret
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=obj["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(arguments,
+                                                 ensure_ascii=False))))
+
+            if not tool_calls:
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+            return ExtractedToolCallInformation(tools_called=True,
+                                                tool_calls=tool_calls,
+                                                content=None)
 
         except Exception:
             logger.exception("Error in extracting tool call from response.")