[Refactor] move tool parsing logic from protocol.py to the tool parser (#27383)

Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Chauncey 2025-10-24 17:53:23 +08:00 committed by GitHub
parent e0ef8a2920
commit 3567816932
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 131 additions and 75 deletions

View File

@ -9,10 +9,10 @@ import regex as re
from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
pytestmark = pytest.mark.cpu_test
@ -67,8 +67,9 @@ EXAMPLE_TOOLS = [
def _compile_and_check(
tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
):
self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
# self = MagicMock(tool_choice="required", tools=tools)
# schema = ChatCompletionRequest._get_json_schema_from_tool(self)
schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide

View File

@ -854,8 +854,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
self.structured_outputs = StructuredOutputsParams(**kwargs)
response_format = self.response_format
json_schema_from_tool = self._get_json_schema_from_tool()
if response_format is not None or json_schema_from_tool is not None:
if response_format is not None:
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if self.structured_outputs is None:
@ -881,10 +880,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
# Set structured output params for tool calling
if json_schema_from_tool is not None:
self.structured_outputs.json = json_schema_from_tool
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
@ -924,72 +919,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
extra_args=extra_args or None,
)
def _get_json_schema_from_tool(self) -> str | dict | None:
    """Build the JSON schema that constrains generation for tool calling.

    Returns ``None`` when no constraint applies (tool use disabled, no
    tools supplied, or choice left to the model), the named tool's own
    parameter schema for a forced function call, or an array schema
    covering every listed tool when ``tool_choice`` is ``"required"``.
    """
    # Tool use disabled entirely, or no tools were supplied.
    if self.tools is None or self.tool_choice == "none":
        return None

    # Forced function call: constrain output to that tool's parameters.
    # NOTE: exact type check (not isinstance) mirrors the original intent.
    if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
        wanted = self.tool_choice.function.name
        by_name = {t.function.name: t.function for t in self.tools}
        if wanted not in by_name:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return by_name[wanted].parameters

    # "required": the model must emit at least one call to a listed tool.
    # Pydantic schema generation cannot be used since the JSON schema has
    # to be constructed for a specific instantiation of a tool list so
    # that parameters of a function are generated correctly based on the
    # chosen function name.
    if self.tool_choice == "required":

        def schema_for(tool: ChatCompletionToolsParam) -> dict:
            # Parameters missing from the request (None or '{}') must
            # still render as an empty object in the final output.
            params = tool.function.parameters or {
                "type": "object",
                "properties": {},
            }
            return {
                "properties": {
                    "name": {"type": "string", "enum": [tool.function.name]},
                    "parameters": params,
                },
                "required": ["name", "parameters"],
            }

        def collect_defs(tools: list[ChatCompletionToolsParam]) -> dict:
            merged: dict[str, dict[str, Any]] = {}
            for tool in tools:
                params = tool.function.parameters
                if params is None:
                    continue
                # pop() hoists "$defs" out of the tool's own parameter
                # dict (mutating it in place, as the original did).
                for def_name, def_schema in params.pop("$defs", {}).items():
                    if def_name in merged and merged[def_name] != def_schema:
                        raise ValueError(
                            f"Tool definition '{def_name}' has "
                            "multiple schemas, which is not "
                            "supported."
                        )
                    merged[def_name] = def_schema
            return merged

        # Build the per-tool anyOf BEFORE collecting defs: collect_defs
        # strips "$defs" from the parameter dicts the anyOf references.
        schema: dict = {
            "type": "array",
            "minItems": 1,
            "items": {
                "type": "object",
                "anyOf": [schema_for(t) for t in self.tools],
            },
        }
        shared_defs = collect_defs(self.tools)
        if shared_defs:
            schema["$defs"] = shared_defs
        return schema

    # "auto" (or anything else): leave generation unconstrained.
    return None
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):

View File

@ -10,7 +10,11 @@ from vllm.entrypoints.openai.protocol import (
DeltaMessage,
ExtractedToolCallInformation,
)
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
from vllm.logger import init_logger
from vllm.sampling_params import (
StructuredOutputsParams,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path
@ -44,6 +48,18 @@ class ToolParser:
"""
Static method that used to adjust the request parameters.
"""
if not request.tools:
return request
json_schema_from_tool = get_json_schema_from_tools(
tool_choice=request.tool_choice, tools=request.tools
)
# Set structured output params for tool calling
if json_schema_from_tool is not None:
if request.structured_outputs is None:
request.structured_outputs = StructuredOutputsParams()
# tool_choice: "Forced Function" or "required" will override
# structured output json settings to make tool calling work correctly
request.structured_outputs.json = json_schema_from_tool
return request
def extract_tool_calls(

View File

@ -112,6 +112,7 @@ class Hermes2ProToolParser(ToolParser):
return delta_text
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are
# marked "special" in some models. Since they are skipped

View File

@ -35,6 +35,7 @@ class Internlm2ToolParser(ToolParser):
self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because internlm use the special
# tokens to indicate the start and end of the tool calls

View File

@ -68,6 +68,7 @@ class JambaToolParser(ToolParser):
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because jamba use the special
# tokens to indicate the start and end of the tool calls

View File

@ -94,6 +94,7 @@ class MistralToolParser(ToolParser):
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if (
not isinstance(self.model_tokenizer, MistralTokenizer)
and request.tools

View File

@ -51,6 +51,7 @@ class Step3ToolParser(ToolParser):
self.tool_block_finished = False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
return request

View File

@ -6,8 +6,18 @@ from json import JSONDecodeError, JSONDecoder
from typing import Any
import partial_json_parser
from openai.types.responses import (
FunctionTool,
ToolChoiceFunction,
)
from openai.types.responses.tool import Tool
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolsParam,
)
def find_common_prefix(s1: str, s2: str) -> str:
"""
@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int:
while i < len(s) and s[i].isspace():
i += 1
return i
def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """Return ``(name, parameters)`` for either supported tool flavor.

    Responses-API ``FunctionTool`` objects carry the fields directly;
    Chat-Completions ``ChatCompletionToolsParam`` nests them under
    ``.function``. Any other type is rejected.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        return tool.function.name, tool.function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """Build the per-tool entry used inside the "required" array schema.

    The entry pins ``name`` to this tool's name and embeds its parameter
    schema; absent/empty parameters render as an empty JSON object.
    """
    name, raw_params = _extract_tool_info(tool)
    if raw_params:
        parameters = raw_params
    else:
        parameters = {"type": "object", "properties": {}}
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": parameters,
        },
        "required": ["name", "parameters"],
    }
def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Merge the ``$defs`` sections of every tool's parameter schema.

    Raises ``ValueError`` when two tools define the same ``$defs`` name
    with different schemas. NOTE: ``pop()`` removes ``$defs`` from each
    tool's own parameter dict in place (mutation is intentional — the
    definitions are hoisted to the top-level schema).
    """
    merged: dict[str, dict[str, Any]] = {}
    for tool in tools:
        _, params = _extract_tool_info(tool)
        if params is None:
            continue
        for name, schema in params.pop("$defs", {}).items():
            if name in merged and merged[name] != schema:
                raise ValueError(
                    f"Tool definition '{name}' has multiple schemas, "
                    "which is not supported."
                )
            merged[name] = schema
    return merged
def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Build the array schema enforcing "at least one call to some tool".

    Each array item must match one tool's ``{name, parameters}`` entry;
    shared ``$defs`` (if any) are hoisted to the schema's top level.
    """
    # Build the anyOf entries first: _get_tool_schema_defs strips "$defs"
    # from the parameter dicts those entries reference.
    schema: dict = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": [_get_tool_schema_from_tool(t) for t in tools],
        },
    }
    shared_defs = _get_tool_schema_defs(tools)
    if shared_defs:
        schema["$defs"] = shared_defs
    return schema
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """Translate ``(tool_choice, tools)`` into a structured-output schema.

    Returns ``None`` when no constraint applies ("none"/"auto"/no tools),
    the chosen tool's parameter schema for a forced function call (both
    the Responses and Chat Completions flavors), or an array schema over
    all tools when ``tool_choice`` is ``"required"``.
    """
    # No tools, or tool use explicitly disabled: nothing to constrain.
    if tools is None or tool_choice in ("none", None):
        return None

    if not isinstance(tool_choice, str):
        # Forced function via the Responses API.
        if isinstance(tool_choice, ToolChoiceFunction):
            target = tool_choice.name
            candidates = {
                t.name: t for t in tools if isinstance(t, FunctionTool)
            }
            if target not in candidates:
                raise ValueError(
                    f"Tool '{target}' has not been passed in `tools`."
                )
            return candidates[target].parameters
        # Forced function via the Chat Completions API.
        if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            target = tool_choice.function.name
            candidates = {
                t.function.name: t
                for t in tools
                if isinstance(t, ChatCompletionToolsParam)
            }
            if target not in candidates:
                raise ValueError(
                    f"Tool '{target}' has not been passed in `tools`."
                )
            return candidates[target].function.parameters

    # "required": the model must emit at least one call to a listed tool.
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # "auto" (or any unrecognized choice): leave output unconstrained.
    return None