From 8f2720def9c32163822d9ccf9a31019f7a892e12 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Thu, 10 Jul 2025 13:56:35 +0800 Subject: [PATCH] [Frontend] Support Tool Calling with both `tool_choice='required'` and `$defs`. (#20629) Signed-off-by: chaunceyjiang --- .../test_completion_with_function_calling.py | 35 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 21 +++++++++++ 2 files changed, 56 insertions(+) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 799648d3992e..eca048d855b5 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": + "Optional parameters for weather query", + }, }, "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, }, }, }, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 14b2253d1dba..b3395c598e74 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -707,6 +707,24 @@ class ChatCompletionRequest(OpenAIBaseModel): "required": ["name", "parameters"] } + def get_tool_schema_defs( + tools: list[ChatCompletionToolsParam]) -> dict: + all_defs = dict[str, dict[str, Any]]() + for tool in tools: + if tool.function.parameters is None: + continue + defs = tool.function.parameters.pop("$defs", {}) + for def_name, def_schema in defs.items(): + if def_name in all_defs and all_defs[ + def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has " + "multiple schemas, which is not " + "supported.") + else: + all_defs[def_name] = def_schema + return all_defs + json_schema = { "type": "array", "minItems": 1, @@ -715,6 +733,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "anyOf": [get_tool_schema(tool) for tool in self.tools] } } + json_schema_defs = get_tool_schema_defs(self.tools) + if json_schema_defs: + json_schema["$defs"] = json_schema_defs return json_schema return None