From 3567816932e674abce3f44ceb0aff03f73b5aaff Mon Sep 17 00:00:00 2001 From: Chauncey Date: Fri, 24 Oct 2025 17:53:23 +0800 Subject: [PATCH] [Refactor] move tool parsing logic from protocol.py to the tool parser (#27383) Co-authored-by: Aaron Pham --- tests/tool_use/test_tool_choice_required.py | 7 +- vllm/entrypoints/openai/protocol.py | 73 +----------- .../tool_parsers/abstract_tool_parser.py | 16 +++ .../openai/tool_parsers/hermes_tool_parser.py | 1 + .../tool_parsers/internlm2_tool_parser.py | 1 + .../openai/tool_parsers/jamba_tool_parser.py | 1 + .../tool_parsers/mistral_tool_parser.py | 1 + .../openai/tool_parsers/step3_tool_parser.py | 1 + vllm/entrypoints/openai/tool_parsers/utils.py | 105 ++++++++++++++++++ 9 files changed, 131 insertions(+), 75 deletions(-) diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index d52c141f6210d..d5572cfbebe3c 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -9,10 +9,10 @@ import regex as re from pydantic import TypeAdapter from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, ChatCompletionToolsParam, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools pytestmark = pytest.mark.cpu_test @@ -67,8 +67,9 @@ EXAMPLE_TOOLS = [ def _compile_and_check( tools: list[ChatCompletionToolsParam], sample_output, should_match: bool ): - self = MagicMock(tool_choice="required", tools=tools) - schema = ChatCompletionRequest._get_json_schema_from_tool(self) + # self = MagicMock(tool_choice="required", tools=tools) + # schema = ChatCompletionRequest._get_json_schema_from_tool(self) + schema = get_json_schema_from_tools(tools=tools, tool_choice="required") assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7d32d5b23f1e0..9782641296d62 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -854,8 +854,7 @@ class ChatCompletionRequest(OpenAIBaseModel): self.structured_outputs = StructuredOutputsParams(**kwargs) response_format = self.response_format - json_schema_from_tool = self._get_json_schema_from_tool() - if response_format is not None or json_schema_from_tool is not None: + if response_format is not None: # If structured outputs wasn't already enabled, # we must enable it for these features to work if self.structured_outputs is None: @@ -881,10 +880,6 @@ class ChatCompletionRequest(OpenAIBaseModel): s_tag_obj = structural_tag.model_dump(by_alias=True) self.structured_outputs.structural_tag = json.dumps(s_tag_obj) - # Set structured output params for tool calling - if json_schema_from_tool is not None: - self.structured_outputs.json = json_schema_from_tool - extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: # Pass in kv_transfer_params via extra_args @@ -924,72 +919,6 @@ class ChatCompletionRequest(OpenAIBaseModel): extra_args=extra_args or None, ) - def _get_json_schema_from_tool(self) -> str | dict | None: - # user has chosen to not use any tool - if self.tool_choice == "none" or self.tools is None: - return None - - # user has chosen to use a named tool - if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: - tool_name = self.tool_choice.function.name - tools = {tool.function.name: tool.function for tool in self.tools} - if tool_name not in tools: - raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") - tool = tools[tool_name] - return tool.parameters - - if self.tool_choice == "required": - # Pydantic schema generation cannot be used since the JSON schema - # has to be constructed for a specific instantiation of a tool list - # so that parameters of a function are correctly generated - # based on the chosen function name - def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: - return { - "properties": { - "name": {"type": "string", "enum": [tool.function.name]}, - # parameters are always generated as '{}' in the final - # output if they are missing from the request - # (i.e. are None or '{}') so the schema is - # updated to produce an empty object in that case - "parameters": tool.function.parameters - if tool.function.parameters - else {"type": "object", "properties": {}}, - }, - "required": ["name", "parameters"], - } - - def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict: - all_defs = dict[str, dict[str, Any]]() - for tool in tools: - if tool.function.parameters is None: - continue - defs = tool.function.parameters.pop("$defs", {}) - for def_name, def_schema in defs.items(): - if def_name in all_defs and all_defs[def_name] != def_schema: - raise ValueError( - f"Tool definition '{def_name}' has " - "multiple schemas, which is not " - "supported." - ) - else: - all_defs[def_name] = def_schema - return all_defs - - json_schema = { - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools], - }, - } - json_schema_defs = get_tool_schema_defs(self.tools) - if json_schema_defs: - json_schema["$defs"] = json_schema_defs - return json_schema - - return None - @model_validator(mode="before") @classmethod def validate_stream_options(cls, data): diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 4733288644680..212326fdafb1e 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -10,7 +10,11 @@ from vllm.entrypoints.openai.protocol import ( DeltaMessage, ExtractedToolCallInformation, ) +from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools from vllm.logger import init_logger +from vllm.sampling_params import ( + StructuredOutputsParams, +) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import import_from_path @@ -44,6 +48,18 @@ class ToolParser: """ Static method that used to adjust the request parameters. """ + if not request.tools: + return request + json_schema_from_tool = get_json_schema_from_tools( + tool_choice=request.tool_choice, tools=request.tools + ) + # Set structured output params for tool calling + if json_schema_from_tool is not None: + if request.structured_outputs is None: + request.structured_outputs = StructuredOutputsParams() + # tool_choice: "Forced Function" or "required" will override + # structured output json settings to make tool calling work correctly + request.structured_outputs.json = json_schema_from_tool return request def extract_tool_calls( diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index ca3239e94377f..6332de42f424e 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -112,6 +112,7 @@ class Hermes2ProToolParser(ToolParser): return delta_text def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because the tool_call tokens are # marked "special" in some models. Since they are skipped diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 958aa3b98fafb..c87bab4353b5b 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -35,6 +35,7 @@ class Internlm2ToolParser(ToolParser): self.position = 0 def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because internlm use the special # tokens to indicate the start and end of the tool calls diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index ca0faabada207..21ee2b762cd0a 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -68,6 +68,7 @@ class JambaToolParser(ToolParser): ) def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because jamba use the special # tokens to indicate the start and end of the tool calls diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 12b3d7bea8a42..dbdf0085367bc 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -94,6 +94,7 @@ class MistralToolParser(ToolParser): ) def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if ( not isinstance(self.model_tokenizer, MistralTokenizer) and request.tools diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index 0a80c5ccc354d..d0255ec085391 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -51,6 +51,7 @@ class Step3ToolParser(ToolParser): self.tool_block_finished = False def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": request.skip_special_tokens = False return request diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index e076ab38e3364..570eb447a4678 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -6,8 +6,18 @@ from json import JSONDecodeError, JSONDecoder from typing import Any import partial_json_parser +from openai.types.responses import ( + FunctionTool, + ToolChoiceFunction, +) +from openai.types.responses.tool import Tool from partial_json_parser.core.options import Allow +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionToolsParam, +) + def find_common_prefix(s1: str, s2: str) -> str: """ @@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int: while i < len(s) and s[i].isspace(): i += 1 return i + + +def _extract_tool_info( + tool: Tool | ChatCompletionToolsParam, +) -> tuple[str, dict[str, Any] | None]: + if isinstance(tool, FunctionTool): + return tool.name, tool.parameters + elif isinstance(tool, ChatCompletionToolsParam): + return tool.function.name, tool.function.parameters + else: + raise TypeError(f"Unsupported tool type: {type(tool)}") + + +def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict: + name, params = _extract_tool_info(tool) + params = params if params else {"type": "object", "properties": {}} + return { + "properties": { + "name": {"type": "string", "enum": [name]}, + "parameters": params, + }, + "required": ["name", "parameters"], + } + + +def _get_tool_schema_defs( + tools: list[Tool | ChatCompletionToolsParam], +) -> dict: + all_defs: dict[str, dict[str, Any]] = {} + for tool in tools: + _, params = _extract_tool_info(tool) + if params is None: + continue + defs = params.pop("$defs", {}) + for def_name, def_schema in defs.items(): + if def_name in all_defs and all_defs[def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has multiple schemas, " + "which is not supported." + ) + all_defs[def_name] = def_schema + return all_defs + + +def _get_json_schema_from_tools( + tools: list[Tool | ChatCompletionToolsParam], +) -> dict: + json_schema = { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools], + }, + } + json_schema_defs = _get_tool_schema_defs(tools) + if json_schema_defs: + json_schema["$defs"] = json_schema_defs + return json_schema + + +def get_json_schema_from_tools( + tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam, + tools: list[FunctionTool | ChatCompletionToolsParam] | None, +) -> str | dict | None: + # tool_choice: "none" + if tool_choice in ("none", None) or tools is None: + return None + # tool_choice: Forced Function (Responses) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ToolChoiceFunction + ): + tool_name = tool_choice.name + tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].parameters + # tool_choice: Forced Function (ChatCompletion) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ChatCompletionNamedToolChoiceParam + ): + tool_name = tool_choice.function.name + tool_map = { + tool.function.name: tool + for tool in tools + if isinstance(tool, ChatCompletionToolsParam) + } + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].function.parameters + # tool_choice: "required" + if tool_choice == "required": + return _get_json_schema_from_tools(tools) + # tool_choice: "auto" + return None