# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

# Unit tests for the Ernie 4.5 tool-call parser. Covers both the one-shot
# (non-streaming) extraction path and the token-by-token streaming path.
import json
from collections.abc import Generator

import pytest

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"


@pytest.fixture(scope="module")
def ernie45_tokenizer():
    # Module-scoped: loading a tokenizer is expensive, so share one instance
    # across all tests in this module.
    return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)


@pytest.fixture
def ernie45_tool_parser(ernie45_tokenizer):
    # Function-scoped: the parser carries per-request streaming state, so each
    # test gets a fresh instance.
    return Ernie45ToolParser(ernie45_tokenizer)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Assert that parsed tool calls match the expected ones.

    Tool-call ids are only checked to be non-empty strings (they are
    generated, not fixed); function arguments are compared as parsed JSON
    objects so that whitespace/key-order formatting differences don't matter.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 0

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function.name == expected_tool_call.function.name
        # Compare arguments as JSON objects to handle formatting differences
        actual_args = json.loads(actual_tool_call.function.arguments)
        expected_args = json.loads(expected_tool_call.function.arguments)
        assert actual_args == expected_args


def test_extract_tool_calls_no_tools(ernie45_tool_parser):
    """Plain text with no tool-call markup is passed through as content."""
    model_output = "This is a test"
    extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


# NOTE(review): the raw model-output fixtures below may originally have
# carried special tool-call marker tokens around each JSON object (per the
# Ernie 4.5 chat template); confirm the literal strings against the model's
# template before editing them.
@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """I need to call two tools to handle these two issues separately.
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            # Leading free text (plus its trailing newline) is surfaced as
            # content rather than being swallowed by the tool-call parse.
            "I need to call two tools to handle these two issues separately.\n",
        ),
    ],
)
def test_extract_tool_calls(
    ernie45_tool_parser, model_output, expected_tool_calls, expected_content
):
    """One-shot extraction: tool calls and surrounding content are split out."""
    extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called
    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
    assert extracted_tool_calls.content == expected_content


def stream_delta_message_generator(
    ernie45_tool_parser: Ernie45ToolParser,
    ernie45_tokenizer: AnyTokenizer,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` one token at a time through the streaming parser.

    The text is encoded once, then fed token-by-token through
    ``detokenize_incrementally`` to reproduce the exact delta strings the
    server would see, and each non-empty ``DeltaMessage`` produced by
    ``extract_tool_calls_streaming`` is yielded.
    """
    all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)

    # Incremental-detokenization state carried across iterations.
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=ernie45_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                # Keep special tokens visible: the parser relies on seeing the
                # tool-call markers in the text stream.
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        # Advance the incremental state for the next token.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


# NOTE(review): same caveat as above — the model-output literals may have
# lost special marker tokens in transit; verify before modifying.
@pytest.mark.parametrize(
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
        (
            """
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """I need to call two tools to handle these two issues separately.
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
""",
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_temperature",
                        arguments=json.dumps(
                            {
                                "location": "Beijing",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_temperature_unit",
                        arguments=json.dumps(
                            {
                                "location": "Guangzhou",
                                "unit": "c",
                            }
                        ),
                    )
                ),
            ],
            "I need to call two tools to handle these two issues separately.\n",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    ernie45_tool_parser,
    ernie45_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    # NOTE(review): ``expected_content`` is accepted (the argvalues are shared
    # with the non-streaming test) but streamed content chunks are not
    # asserted here — only the reassembled tool calls are checked.
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
    tool_calls_dict = {}
    for delta_message in stream_delta_message_generator(
        ernie45_tool_parser, ernie45_tokenizer, model_output, request
    ):
        # Skip deltas that carry no information at all.
        if (
            delta_message.role is None
            and delta_message.content is None
            and delta_message.reasoning is None
            and len(delta_message.tool_calls) == 0
        ):
            continue

        tool_calls = delta_message.tool_calls
        for tool_call_chunk in tool_calls:
            index = tool_call_chunk.index
            if index not in tool_calls_dict:
                # First chunk for this index: normalize missing arguments to
                # "" so later chunks can be concatenated onto it.
                if tool_call_chunk.function.arguments is None:
                    tool_call_chunk.function.arguments = ""
                tool_calls_dict[index] = tool_call_chunk
            else:
                # Subsequent chunks only append argument fragments.
                tool_calls_dict[
                    index
                ].function.arguments += tool_call_chunk.function.arguments

    actual_tool_calls = list(tool_calls_dict.values())
    assert len(actual_tool_calls) > 0
    # check tool call format
    assert_tool_calls(actual_tool_calls, expected_tool_calls)