[Feature][V1] Support tool_choice: required when using Xgrammar as the StructuredOutputBackend. (#17845)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Chauncey 2025-05-13 14:01:31 +08:00 committed by GitHub
parent 61e0a506a3
commit dc1a821768
5 changed files with 160 additions and 13 deletions
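Only the dependency bump, the new end-to-end test, and the relaxed xgrammar feature checks are reproduced below; the serving-side change that actually wires tool_choice="required" into the xgrammar structured-output backend lives in the other changed files. As a rough, hypothetical sketch of the idea (the helper name and exact schema shape are illustrative, not code from this commit): the declared tools are folded into a single JSON schema meaning "a non-empty array of tool calls, each matching one of the declared tools", and that schema is compiled by the backend to constrain decoding. Such a schema leans on minItems, which is presumably why the commit also bumps xgrammar to 0.1.19 and stops treating minItems/maxItems as unsupported.

import json


def build_required_tool_choice_schema(tools: list[dict]) -> str:
    """Hypothetical helper: turn the request's `tools` into a JSON schema
    that forces at least one tool call matching one of the declared tools."""
    call_schemas = [{
        "type": "object",
        "properties": {
            "name": {
                "const": tool["function"]["name"]
            },
            "parameters": tool["function"].get("parameters", {}),
        },
        "required": ["name", "parameters"],
    } for tool in tools]

    return json.dumps({
        "type": "array",
        "minItems": 1,  # "required" == at least one call must be emitted
        "items": {
            "anyOf": call_schemas
        },
    })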


@@ -22,7 +22,7 @@ lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
lark == 1.2.2
-xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64"
+xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs


@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_required_tool_use(client: openai.AsyncOpenAI, model_name: str):
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to find the weather for, e.g. 'Vienna'",
                            "default": "Vienna",
                        },
                        "country": {
                            "type": "string",
                            "description":
                            "The country that the city is in, e.g. 'Austria'",
                        },
                        "unit": {
                            "type": "string",
                            "description":
                            "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["country", "unit"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_forecast",
                "description": "Get the weather forecast for a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to get the forecast for, e.g. 'Vienna'",
                            "default": "Vienna",
                        },
                        "country": {
                            "type": "string",
                            "description":
                            "The country that the city is in, e.g. 'Austria'",
                        },
                        "days": {
                            "type": "integer",
                            "description":
                            "Number of days to get the forecast for (1-7)",
                        },
                        "unit": {
                            "type": "string",
                            "description":
                            "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["country", "days", "unit"],
                },
            },
        },
    ]

    messages = [
        {
            "role": "user",
            "content": "Hi! How are you doing today?"
        },
        {
            "role": "assistant",
            "content": "I'm doing well! How can I help you?"
        },
        {
            "role": "user",
            "content": "Can you tell me what the current weather is in Berlin and the "
            "forecast for the next 5 days, in fahrenheit?",
        },
    ]

    # Non-streaming test
    chat_completion = await client.chat.completions.create(
        messages=messages,
        model=model_name,
        tools=tools,
        tool_choice="required",
    )

    assert chat_completion.choices[0].message.tool_calls is not None
    assert len(chat_completion.choices[0].message.tool_calls) > 0

    # Streaming test
    stream = await client.chat.completions.create(
        messages=messages,
        model=model_name,
        tools=tools,
        tool_choice="required",
        stream=True,
    )

    output = []
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.tool_calls:
            output.extend(chunk.choices[0].delta.tool_calls)

    assert len(output) > 0
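The streaming assertion above only checks that some tool-call deltas arrive. To inspect the actual calls, the deltas have to be stitched back together by index, since the function name arrives once and the JSON arguments arrive as string fragments. A minimal sketch of such a helper (hypothetical, not part of this commit), assuming the OpenAI client's ChoiceDeltaToolCall shape:

def merge_tool_call_deltas(deltas) -> dict[int, dict[str, str]]:
    """Accumulate streamed tool-call deltas into {index: {name, arguments}}."""
    calls: dict[int, dict[str, str]] = {}
    for delta in deltas:
        call = calls.setdefault(delta.index, {"name": "", "arguments": ""})
        if delta.function is not None:
            if delta.function.name:
                call["name"] = delta.function.name
            if delta.function.arguments:
                call["arguments"] += delta.function.arguments
    return calls


# e.g. in the streaming branch above:
#     calls = merge_tool_call_deltas(output)
#     assert all(call["name"] for call in calls.values())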


@@ -74,7 +74,9 @@ def sample_json_schema():
                    },
                    "required": ["company", "duration", "position"],
                    "additionalProperties": False
-                }
+                },
+                "minItems": 0,
+                "maxItems": 3
            }
        },
        "required":


@@ -57,14 +57,6 @@ def unsupported_array_schemas():
            "type": "array",
            "maxContains": 5
        },
-        {
-            "type": "array",
-            "minItems": 1
-        },
-        {
-            "type": "array",
-            "maxItems": 10
-        },
    ]
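These array schemas are dropped from the "unsupported" fixture because xgrammar 0.1.19 can compile minItems/maxItems constraints, and the next hunk removes the same keywords from the feature check itself. For example, a schema like the following (illustrative, not taken from this commit) would previously have been flagged as unsupported but should now compile with the xgrammar backend, assuming nothing else in it is unsupported:

bounded_list_schema = {
    "type": "array",
    "items": {
        "type": "string"
    },
    "minItems": 1,
    "maxItems": 10,
}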


@@ -215,9 +215,8 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
        # Check for array unsupported keywords
        if obj.get("type") == "array" and any(
-                key in obj
-                for key in ("uniqueItems", "contains", "minContains",
-                            "maxContains", "minItems", "maxItems")):
+                key in obj for key in ("uniqueItems", "contains",
+                                       "minContains", "maxContains")):
            return True

        # Unsupported keywords for strings