diff --git a/requirements/common.txt b/requirements/common.txt index f537b3aab541..80f90e60007e 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -22,7 +22,7 @@ lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py new file mode 100644 index 000000000000..dad76b54c5e9 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +# RemoteOpenAIServer wraps the server under test (see the `server` fixture) +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_required_tool_use(client: openai.AsyncOpenAI, model_name: 
str): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "unit"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, + ] + + messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" 
+ }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, + ] + + # Non-streaming test + chat_completion = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + ) + + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + + # Streaming test + stream = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + stream=True, + ) + + output = [] + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + + assert len(output) > 0 diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index bdee0bb8da7d..8c03f04330dd 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -74,7 +74,9 @@ def sample_json_schema(): }, "required": ["company", "duration", "position"], "additionalProperties": False - } + }, + "minItems": 0, + "maxItems": 3 } }, "required": diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 1cefe8726df7..ffc0bceeee49 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -57,14 +57,6 @@ def unsupported_array_schemas(): "type": "array", "maxContains": 5 }, - { - "type": "array", - "minItems": 1 - }, - { - "type": "array", - "maxItems": 10 - }, ] diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index baa478bc63bd..2ce2be337ecb 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -215,9 +215,8 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Check for array unsupported 
keywords if obj.get("type") == "array" and any( - key in obj - for key in ("uniqueItems", "contains", "minContains", - "maxContains", "minItems", "maxItems")): + key in obj for key in ("uniqueItems", "contains", + "minContains", "maxContains")): return True # Unsupported keywords for strings