mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 01:34:40 +08:00
[Refactor] move tool parsing logic from protocol.py to the tool parser (#27383)
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
e0ef8a2920
commit
3567816932
@@ -9,10 +9,10 @@ import regex as re
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -67,8 +67,9 @@ EXAMPLE_TOOLS = [
|
||||
def _compile_and_check(
|
||||
tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
|
||||
):
|
||||
self = MagicMock(tool_choice="required", tools=tools)
|
||||
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
|
||||
# self = MagicMock(tool_choice="required", tools=tools)
|
||||
# schema = ChatCompletionRequest._get_json_schema_from_tool(self)
|
||||
schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||
|
||||
@ -854,8 +854,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
self.structured_outputs = StructuredOutputsParams(**kwargs)
|
||||
|
||||
response_format = self.response_format
|
||||
json_schema_from_tool = self._get_json_schema_from_tool()
|
||||
if response_format is not None or json_schema_from_tool is not None:
|
||||
if response_format is not None:
|
||||
# If structured outputs wasn't already enabled,
|
||||
# we must enable it for these features to work
|
||||
if self.structured_outputs is None:
|
||||
@ -881,10 +880,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
|
||||
|
||||
# Set structured output params for tool calling
|
||||
if json_schema_from_tool is not None:
|
||||
self.structured_outputs.json = json_schema_from_tool
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
# Pass in kv_transfer_params via extra_args
|
||||
@ -924,72 +919,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
extra_args=extra_args or None,
|
||||
)
|
||||
|
||||
def _get_json_schema_from_tool(self) -> str | dict | None:
    """Derive a structured-output JSON schema from this request's tools.

    Returns:
        The forced tool's parameter schema when ``tool_choice`` names a
        specific tool, an array-of-tool-calls schema when ``tool_choice``
        is ``"required"``, and ``None`` when no schema constraint applies
        (``tool_choice == "none"``, no tools, or ``"auto"``).

    Raises:
        ValueError: if a named ``tool_choice`` is not present in ``tools``,
            or if two tools declare conflicting ``$defs`` entries.
    """
    # user has chosen to not use any tool
    if self.tool_choice == "none" or self.tools is None:
        return None

    # user has chosen to use a named tool
    # NOTE(review): exact-type check, not isinstance — a subclass of
    # ChatCompletionNamedToolChoiceParam would fall through; confirm intended.
    if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = self.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in self.tools}
        if tool_name not in tools:
            raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        return tool.parameters

    if self.tool_choice == "required":
        # Pydantic schema generation cannot be used since the JSON schema
        # has to be constructed for a specific instantiation of a tool list
        # so that parameters of a function are correctly generated
        # based on the chosen function name

        def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
            # Schema for one tool call: a literal "name" plus that tool's
            # parameter schema.
            return {
                "properties": {
                    "name": {"type": "string", "enum": [tool.function.name]},
                    # parameters are always generated as '{}' in the final
                    # output if they are missing from the request
                    # (i.e. are None or '{}') so the schema is
                    # updated to produce an empty object in that case
                    "parameters": tool.function.parameters
                    if tool.function.parameters
                    else {"type": "object", "properties": {}},
                },
                "required": ["name", "parameters"],
            }

        def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict:
            # Merge every tool's "$defs" table into one shared top-level table.
            # NOTE(review): pop() mutates tool.function.parameters in place;
            # the schemas embedded above hold references to the same dicts,
            # so the $defs are effectively hoisted out of each "anyOf" branch.
            all_defs = dict[str, dict[str, Any]]()
            for tool in tools:
                if tool.function.parameters is None:
                    continue
                defs = tool.function.parameters.pop("$defs", {})
                for def_name, def_schema in defs.items():
                    if def_name in all_defs and all_defs[def_name] != def_schema:
                        raise ValueError(
                            f"Tool definition '{def_name}' has "
                            "multiple schemas, which is not "
                            "supported."
                        )
                    else:
                        all_defs[def_name] = def_schema
            return all_defs

        # The model must emit a non-empty array of tool calls, each matching
        # one of the declared tools.
        json_schema = {
            "type": "array",
            "minItems": 1,
            "items": {
                "type": "object",
                "anyOf": [get_tool_schema(tool) for tool in self.tools],
            },
        }
        json_schema_defs = get_tool_schema_defs(self.tools)
        if json_schema_defs:
            json_schema["$defs"] = json_schema_defs
        return json_schema

    # tool_choice == "auto" (or any other value): no constraint
    return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
|
||||
@ -10,7 +10,11 @@ from vllm.entrypoints.openai.protocol import (
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import import_from_path
|
||||
@ -44,6 +48,18 @@ class ToolParser:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
if not request.tools:
|
||||
return request
|
||||
json_schema_from_tool = get_json_schema_from_tools(
|
||||
tool_choice=request.tool_choice, tools=request.tools
|
||||
)
|
||||
# Set structured output params for tool calling
|
||||
if json_schema_from_tool is not None:
|
||||
if request.structured_outputs is None:
|
||||
request.structured_outputs = StructuredOutputsParams()
|
||||
# tool_choice: "Forced Function" or "required" will override
|
||||
# structured output json settings to make tool calling work correctly
|
||||
request.structured_outputs.json = json_schema_from_tool
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
|
||||
@ -112,6 +112,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
return delta_text
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# do not skip special tokens because the tool_call tokens are
|
||||
# marked "special" in some models. Since they are skipped
|
||||
|
||||
@ -35,6 +35,7 @@ class Internlm2ToolParser(ToolParser):
|
||||
self.position = 0
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# do not skip special tokens because internlm use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
|
||||
@ -68,6 +68,7 @@ class JambaToolParser(ToolParser):
|
||||
)
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# do not skip special tokens because jamba use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
|
||||
@ -94,6 +94,7 @@ class MistralToolParser(ToolParser):
|
||||
)
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if (
|
||||
not isinstance(self.model_tokenizer, MistralTokenizer)
|
||||
and request.tools
|
||||
|
||||
@ -51,6 +51,7 @@ class Step3ToolParser(ToolParser):
|
||||
self.tool_block_finished = False
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Apply the shared adjustments, then keep special tokens in the output.

    The step3 tool-call markers are "special" tokens; skipping them would
    strip the markers this parser relies on, so detokenization must keep
    them whenever tool calling is active.
    """
    request = super().adjust_request(request)
    tool_calling_active = bool(request.tools) and request.tool_choice != "none"
    if tool_calling_active:
        request.skip_special_tokens = False
    return request
|
||||
|
||||
@ -6,8 +6,18 @@ from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from openai.types.responses import (
|
||||
FunctionTool,
|
||||
ToolChoiceFunction,
|
||||
)
|
||||
from openai.types.responses.tool import Tool
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """Normalize a tool object to a ``(name, parameters)`` pair.

    Accepts either the Responses-API ``FunctionTool`` or the chat-completions
    ``ChatCompletionToolsParam`` shape; anything else is rejected.

    Raises:
        TypeError: if *tool* is neither supported type.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        function = tool.function
        return function.name, function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
|
||||
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """Build the JSON schema for a single forced/required tool call.

    The resulting object constrains ``name`` to this tool's name and
    ``parameters`` to the tool's declared parameter schema.
    """
    name, parameters = _extract_tool_info(tool)
    if not parameters:
        # Missing parameters render as '{}' in the final output, so emit an
        # explicit empty-object schema for that case.
        parameters = {"type": "object", "properties": {}}
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": parameters,
        },
        "required": ["name", "parameters"],
    }
|
||||
|
||||
|
||||
def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Merge every tool's ``$defs`` table into one shared dictionary.

    NOTE: this pops ``"$defs"`` out of each tool's parameter dict in place;
    callers rely on that hoisting so the combined schema carries a single
    top-level ``$defs`` table.

    Raises:
        ValueError: if two tools declare the same definition name with
            different schemas.
    """
    merged: dict[str, dict[str, Any]] = {}
    for tool in tools:
        _, params = _extract_tool_info(tool)
        if params is None:
            continue
        for name, schema in params.pop("$defs", {}).items():
            if name in merged and merged[name] != schema:
                raise ValueError(
                    f"Tool definition '{name}' has multiple schemas, "
                    "which is not supported."
                )
            merged[name] = schema
    return merged
|
||||
|
||||
|
||||
def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Schema for ``tool_choice="required"``: a non-empty array of calls,
    each matching one of the declared tools."""
    # Build the per-tool branches before collecting $defs: the helper below
    # pops "$defs" out of the same parameter dicts these branches reference.
    branches = [_get_tool_schema_from_tool(tool) for tool in tools]
    schema: dict = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": branches,
        },
    }
    defs = _get_tool_schema_defs(tools)
    if defs:
        schema["$defs"] = defs
    return schema
|
||||
|
||||
|
||||
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """Translate a request's tool configuration into a JSON schema.

    Returns the forced tool's parameter schema for a named choice, an
    array-of-tool-calls schema for ``"required"``, and ``None`` for
    ``"none"``/``"auto"`` or when no tools are supplied.

    Raises:
        ValueError: if a named tool is not present in *tools*.
    """
    # Nothing to constrain: no tools, or tool use explicitly disabled.
    if tools is None or tool_choice in ("none", None):
        return None

    # Forced function, Responses API flavour. (A str can never be an
    # instance of these model classes, so no extra str guard is needed.)
    if isinstance(tool_choice, ToolChoiceFunction):
        wanted = tool_choice.name
        by_name = {
            tool.name: tool for tool in tools if isinstance(tool, FunctionTool)
        }
        if wanted not in by_name:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return by_name[wanted].parameters

    # Forced function, ChatCompletion flavour.
    if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
        wanted = tool_choice.function.name
        by_name = {
            tool.function.name: tool
            for tool in tools
            if isinstance(tool, ChatCompletionToolsParam)
        }
        if wanted not in by_name:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return by_name[wanted].function.parameters

    # "required": the model must emit at least one well-formed tool call.
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # "auto" (or any other value): leave the output unconstrained.
    return None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user