mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 19:55:01 +08:00
230 lines
7.3 KiB
Python
230 lines
7.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
from json import JSONDecodeError, JSONDecoder
|
|
from typing import Any
|
|
|
|
import partial_json_parser
|
|
from openai.types.responses import (
|
|
FunctionTool,
|
|
ToolChoiceFunction,
|
|
)
|
|
from openai.types.responses.tool import Tool
|
|
from partial_json_parser.core.options import Allow
|
|
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ChatCompletionNamedToolChoiceParam,
|
|
ChatCompletionToolsParam,
|
|
)
|
|
|
|
|
|
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'

    Returns the empty string when the inputs share no leading characters.
    """
    # Count matching leading characters, then slice once instead of growing
    # the result one character at a time.
    shared_len = 0
    for left, right in zip(s1, s2):
        if left != right:
            break
        shared_len += 1
    return s1[:shared_len]
|
|
|
|
|
|
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'

    Returns the empty string when the inputs share no such trailing characters.
    """
    # Walk both strings from the end, counting matching non-alphanumeric
    # characters, then slice once instead of prepending character by character.
    shared_len = 0
    for left, right in zip(reversed(s1), reversed(s2)):
        if left != right or left.isalnum():
            break
        shared_len += 1
    # Slice from an explicit start index: s1[-0:] would return the whole string.
    return s1[len(s1) - shared_len :]
|
|
|
|
|
|
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'

    """
    shared_tail = find_common_suffix(curr, old)

    # Drop the shared tail from the previous string before computing the
    # shared head, so closing characters are not counted as part of the head.
    old_without_tail = old.removesuffix(shared_tail)
    shared_head = find_common_prefix(curr, old_without_tail)

    # Strip the shared tail first, then the shared head (once, at the start
    # only, in case the head text is mirrored later in the string).
    return curr.removesuffix(shared_tail).removeprefix(shared_head)
|
|
|
|
|
|
def find_all_indices(string: str, substring: str) -> list[int]:
    """
    Collect every starting index at which `substring` occurs in `string`,
    in ascending order. Overlapping occurrences are included because the
    search resumes one character after each hit. Useful for tool call
    extraction.
    """
    positions: list[int] = []
    cursor = string.find(substring)
    while cursor != -1:
        positions.append(cursor)
        cursor = string.find(substring, cursor + 1)
    return positions
|
|
|
|
|
|
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Parse possibly-incomplete JSON, tolerating trailing extra data.

    Tries partial_json_parser first; on success the whole input counts as
    consumed and `(value, len(input_str))` is returned. If that raises a
    JSONDecodeError whose message mentions "Extra data", the result of
    `JSONDecoder().raw_decode(input_str)` is returned instead. Any other
    JSONDecodeError is re-raised unchanged.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as exc:
        if "Extra data" not in exc.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
|
|
|
|
|
def is_complete_json(input_str: str) -> bool:
    """
    Report whether `input_str` parses with `json.loads`.

    Returns True if parsing succeeds and False if it raises JSONDecodeError;
    the parsed value itself is discarded.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True
|
|
|
|
|
|
def consume_space(i: int, s: str) -> int:
    """
    Starting at index `i`, advance past consecutive whitespace in `s`.

    Returns the index of the first non-whitespace character at or after `i`,
    or the first index at or beyond `len(s)` if only whitespace remains.
    """
    end = len(s)
    pos = i
    while pos < end:
        if not s[pos].isspace():
            break
        pos += 1
    return pos
|
|
|
|
|
|
def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """
    Return the `(name, parameters)` pair of a tool definition.

    A `FunctionTool` carries both fields directly, while a
    `ChatCompletionToolsParam` carries them on its `function` attribute.

    Raises:
        TypeError: if `tool` is neither of the two supported types.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters

    if isinstance(tool, ChatCompletionToolsParam):
        function = tool.function
        return function.name, function.parameters

    raise TypeError(f"Unsupported tool type: {type(tool)}")
|
|
|
|
|
|
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """
    Build the schema fragment describing one call to `tool`.

    The fragment requires a `name` property restricted to the tool's name and
    a `parameters` property that follows the tool's own parameter schema. When
    the tool has no (or empty) parameters, an empty object schema is used.
    """
    tool_name, parameters = _extract_tool_info(tool)
    if not parameters:
        parameters = {"type": "object", "properties": {}}

    name_schema = {"type": "string", "enum": [tool_name]}
    return {
        "properties": {
            "name": name_schema,
            "parameters": parameters,
        },
        "required": ["name", "parameters"],
    }
|
|
|
|
|
|
def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Merge the `$defs` sections of every tool's parameter schema into one dict.

    NOTE: `$defs` is popped from each tool's parameters dict, i.e. the dict
    returned by `_extract_tool_info` is modified in place.

    Raises:
        ValueError: if two tools define the same `$defs` name with schemas
            that are not equal.
    """
    merged: dict[str, dict[str, Any]] = {}
    for tool in tools:
        params = _extract_tool_info(tool)[1]
        if params is None:
            continue
        # Removes "$defs" from the tool's parameters as a side effect.
        tool_defs = params.pop("$defs", {})
        for def_name, def_schema in tool_defs.items():
            conflicting = def_name in merged and merged[def_name] != def_schema
            if conflicting:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            merged[def_name] = def_schema
    return merged
|
|
|
|
|
|
def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Build a JSON schema for a non-empty array of tool calls.

    Each array item must be an object matching any one of the per-tool
    schemas. Definitions collected from the tools' `$defs` sections are
    attached at the top level under `$defs` when there are any.
    """
    # Build the per-tool schemas before collecting the shared definitions,
    # keeping the same call order as before.
    per_tool_schemas = [_get_tool_schema_from_tool(tool) for tool in tools]
    json_schema: dict[str, Any] = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": per_tool_schemas,
        },
    }

    shared_defs = _get_tool_schema_defs(tools)
    if shared_defs:
        json_schema["$defs"] = shared_defs
    return json_schema
|
|
|
|
|
|
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """
    Select the JSON schema implied by a request's `tool_choice`.

    Returns:
        - None when `tool_choice` is "none"/None or `tools` is None.
        - The named tool's `parameters` when `tool_choice` forces a specific
          function (Responses `ToolChoiceFunction` or ChatCompletion
          `ChatCompletionNamedToolChoiceParam`).
        - The array-of-tool-calls schema when `tool_choice` is "required".
        - None for anything else (e.g. "auto").

    Raises:
        ValueError: if a forced function's name does not match any tool of
            the corresponding type in `tools`.
    """
    # tool_choice: "none"
    if tool_choice in ("none", None) or tools is None:
        return None

    if not isinstance(tool_choice, str):
        # tool_choice: Forced Function (Responses)
        if isinstance(tool_choice, ToolChoiceFunction):
            tool_name = tool_choice.name
            responses_tools = {}
            for tool in tools:
                if isinstance(tool, FunctionTool):
                    responses_tools[tool.name] = tool
            if tool_name not in responses_tools:
                raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
            return responses_tools[tool_name].parameters

        # tool_choice: Forced Function (ChatCompletion)
        if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = tool_choice.function.name
            chat_tools = {}
            for tool in tools:
                if isinstance(tool, ChatCompletionToolsParam):
                    chat_tools[tool.function.name] = tool
            if tool_name not in chat_tools:
                raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
            return chat_tools[tool_name].function.parameters

    # tool_choice: "required"
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # tool_choice: "auto"
    return None
|