Chauncey 3567816932
[Refactor] move tool parsing logic from protocol.py to the tool parser (#27383)
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
2025-10-24 09:53:23 +00:00

230 lines
7.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any
import partial_json_parser
from openai.types.responses import (
FunctionTool,
ToolChoiceFunction,
)
from openai.types.responses.tool import Tool
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolsParam,
)
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest leading run of characters shared by ``s1`` and ``s2``.

    The comparison is symmetric, so the argument order does not matter. An
    empty string is returned when the first characters differ or when either
    input is empty.

    This is a utility for working with JSON produced by partial_json_parser
    while streaming: comparing two successive renderings shows which leading
    text is already stable, so that close-quotes, close-brackets and
    close-braces are not emitted prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    matched = 0
    # zip stops at the shorter string, which bounds the scan for us.
    for left, right in zip(s1, s2):
        if left != right:
            break
        matched += 1
    return s1[:matched]
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the trailing run of non-alphanumeric characters shared by ``s1``
    and ``s2``.

    The strings are walked backwards in lockstep. The walk stops at the first
    position where the characters differ OR where the shared character is
    alphanumeric, so the result only ever contains punctuation/whitespace
    such as closing quotes, brackets and braces. The argument order does not
    matter.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    kept = 0
    for left, right in zip(reversed(s1), reversed(s2)):
        # An alphanumeric character ends the suffix even when both sides agree.
        if left != right or left.isalnum():
            break
        kept += 1
    # Slice from an explicit start index so that kept == 0 yields "".
    return s1[len(s1) - kept:]
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Return the text in the middle of ``curr`` that is new relative to ``old``.

    Both strings are expected to share a common prefix and/or suffix. The
    shared suffix (as computed by find_common_suffix) and the shared prefix
    (as computed by find_common_prefix) are stripped from ``curr``; whatever
    remains is returned.

    This is a utility for working with JSON produced by partial_json_parser
    while streaming, so that close-quotes, close-brackets and close-braces
    are not returned prematurely. The order of arguments IS important: the
    new version of the partially-parsed JSON must be the first argument, and
    the second argument must be the version from the previous generation.
    The return value is the text that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'
    """
    shared_tail = find_common_suffix(curr, old)
    # The shared tail is removed from the previous text first so that the
    # prefix comparison only sees content that precedes the closing characters.
    old_body = old.removesuffix(shared_tail)
    shared_head = find_common_prefix(curr, old_body)
    # Strip the tail, then the head; each is removed at most once and only
    # from its own end of the string.
    return curr.removesuffix(shared_tail).removeprefix(shared_head)
def find_all_indices(string: str, substring: str) -> list[int]:
    """
    Return every starting index at which ``substring`` occurs in ``string``.

    Each search resumes one character after the previous hit, so overlapping
    occurrences are all reported. The indices are in ascending order, and the
    list is empty when there is no occurrence. Useful for tool call
    extraction.
    """
    hits: list[int] = []
    cursor = -1
    # str.find returns -1 once no further occurrence exists.
    while (cursor := string.find(substring, cursor + 1)) != -1:
        hits.append(cursor)
    return hits
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Parse possibly-incomplete JSON and report how much of the input was used.

    The input is first handed to ``partial_json_parser.loads`` with the given
    ``flags``; on success the whole input counts as consumed. If that raises
    a ``JSONDecodeError`` whose message contains "Extra data", the input is
    re-parsed with ``JSONDecoder.raw_decode``, which returns the first JSON
    value together with the index at which it ends.

    Returns:
        A ``(parsed_value, consumed_length)`` tuple.

    Raises:
        JSONDecodeError: re-raised unchanged for any decode error other than
            "Extra data".
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        # Trailing content after a JSON value: decode just the leading value
        # and let raw_decode report where it stops.
        return JSONDecoder().raw_decode(input_str)
    return parsed, len(input_str)
def is_complete_json(input_str: str) -> bool:
    """
    Report whether ``input_str`` parses as a complete JSON document.

    Returns True when ``json.loads`` accepts the string and False when it
    raises ``JSONDecodeError``. The parsed value itself is discarded.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True
def consume_space(i: int, s: str) -> int:
    """
    Advance index ``i`` past any run of whitespace in ``s``.

    Returns the index of the first character at or after ``i`` that is not
    whitespace (per ``str.isspace``). If ``i`` is already at or beyond the end
    of ``s``, or only whitespace remains, the returned index is at or beyond
    ``len(s)``.
    """
    limit = len(s)
    pos = i
    while True:
        if pos >= limit or not s[pos].isspace():
            return pos
        pos += 1
def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """
    Return the ``(name, parameters)`` pair for a tool definition.

    ``FunctionTool`` instances carry both fields directly on the object;
    ``ChatCompletionToolsParam`` instances carry them under ``.function``.

    Raises:
        TypeError: if ``tool`` is an instance of neither supported type.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        fn = tool.function
        return fn.name, fn.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """
    Build the schema fragment that describes a single call to ``tool``.

    The fragment requires two properties: ``name``, constrained by an enum to
    exactly this tool's name, and ``parameters``, constrained by the tool's
    own parameter schema. When the tool has no (or empty) parameters, an
    empty object schema is used instead.

    The tool's parameter dict is embedded by reference, not copied.
    """
    tool_name, tool_params = _extract_tool_info(tool)
    if not tool_params:
        # Missing or empty parameters fall back to "any object, no properties".
        tool_params = {"type": "object", "properties": {}}
    name_schema = {"type": "string", "enum": [tool_name]}
    return {
        "properties": {"name": name_schema, "parameters": tool_params},
        "required": ["name", "parameters"],
    }
def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Collect the ``$defs`` entries of every tool into one mapping.

    Tools whose parameters are ``None`` are skipped. For every other tool the
    ``"$defs"`` key is popped from its parameters dict, i.e. the tool's
    parameters are modified in place and no longer contain ``"$defs"``
    afterwards. The popped definitions are merged by name.

    Returns:
        A dict mapping definition name to its schema; empty when no tool
        declares any definitions.

    Raises:
        ValueError: if two tools declare the same definition name with
            schemas that are not equal.
    """
    merged: dict[str, dict[str, Any]] = {}
    for tool in tools:
        params = _extract_tool_info(tool)[1]
        if params is None:
            continue
        # NOTE: pop (not get) deliberately removes "$defs" from the tool's
        # own parameters dict.
        for def_name, def_schema in params.pop("$defs", {}).items():
            # A name seen before is only acceptable with an equal schema.
            if merged.get(def_name, def_schema) != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            merged[def_name] = def_schema
    return merged
def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Build a JSON schema for a non-empty array of tool calls.

    Every array item must be an object matching one of the per-tool schemas
    produced by ``_get_tool_schema_from_tool``. Definitions gathered by
    ``_get_tool_schema_defs`` are attached under a top-level ``"$defs"`` key,
    which is only present when at least one definition was found.

    The per-tool schemas hold the tools' parameter dicts by reference and
    ``_get_tool_schema_defs`` pops ``"$defs"`` from those same dicts, so the
    nested parameter schemas in the result carry no ``"$defs"`` of their own.
    """
    per_tool_schemas = [_get_tool_schema_from_tool(tool) for tool in tools]
    schema: dict[str, Any] = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": per_tool_schemas,
        },
    }
    shared_defs = _get_tool_schema_defs(tools)
    if shared_defs:
        schema["$defs"] = shared_defs
    return schema
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """
    Return the JSON schema implied by ``tool_choice``, or None if there is none.

    Behaviour by ``tool_choice``:
      * ``"none"`` / ``None``, or ``tools`` is ``None``: returns None.
      * ``ToolChoiceFunction`` (Responses forced function): returns the
        ``parameters`` of the ``FunctionTool`` with the requested name.
      * ``ChatCompletionNamedToolChoiceParam`` (ChatCompletion forced
        function): returns ``function.parameters`` of the
        ``ChatCompletionToolsParam`` with the requested name.
      * ``"required"``: returns the array-of-tool-calls schema built by
        ``_get_json_schema_from_tools``.
      * anything else (e.g. ``"auto"``): returns None.

    Raises:
        ValueError: if a forced function names a tool that is not present
            (with the matching tool type) in ``tools``.
    """
    if tools is None or tool_choice in ("none", None):
        return None

    # String choices: only "required" produces a schema; "auto" does not.
    if isinstance(tool_choice, str):
        if tool_choice == "required":
            return _get_json_schema_from_tools(tools)
        return None

    # Forced function (Responses)
    if isinstance(tool_choice, ToolChoiceFunction):
        wanted = tool_choice.name
        responses_tools = {
            candidate.name: candidate
            for candidate in tools
            if isinstance(candidate, FunctionTool)
        }
        if wanted not in responses_tools:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return responses_tools[wanted].parameters

    # Forced function (ChatCompletion)
    if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
        wanted = tool_choice.function.name
        chat_tools = {
            candidate.function.name: candidate
            for candidate in tools
            if isinstance(candidate, ChatCompletionToolsParam)
        }
        if wanted not in chat_tools:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return chat_tools[wanted].function.parameters

    return None