From 14e53ed11f5134381bde03484148cb3cd1b69cd6 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 2 Apr 2025 05:00:08 -0400 Subject: [PATCH] [V1] Fix json_object support with xgrammar (#15488) Signed-off-by: Russell Bryant --- requirements/common.txt | 2 +- .../entrypoints/llm/test_struct_output_generate.py | 12 ++---------- vllm/model_executor/guided_decoding/__init__.py | 6 ------ .../guided_decoding/xgrammar_decoding.py | 5 ++++- vllm/v1/structured_output/backend_xgrammar.py | 4 +++- 5 files changed, 10 insertions(+), 19 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 48e58c85c89b1..08fee27fe681b 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -21,7 +21,7 @@ lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.17; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 0ffee08c23462..d848490b89e8a 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -125,17 +125,9 @@ def test_structured_output( print(generated_text) assert generated_text is not None - # Parse to verify it is valid JSON + # Parse to verify it is a valid JSON object parsed_json = json.loads(generated_text) - allowed_types: tuple[type, ...] = (dict, ) - if guided_decoding_backend.startswith("xgrammar"): - # TODO - we are currently too permissive with xgrammar and - # allow # any valid json (typically comes back as a list or - # object). We can fix this by specifying a jsonschema of - # {"type": "object"}, # but we need this fix in a release - # first: https://github.com/mlc-ai/xgrammar/pull/264 - allowed_types = (dict, list) - assert isinstance(parsed_json, allowed_types) + assert isinstance(parsed_json, dict) # # Test 3: test a jsonschema incompatible with xgrammar diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index cecb3a8a1d4a8..d4fd11fd2e305 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -79,12 +79,6 @@ def maybe_backend_fallback( "xgrammar does not support Lark grammars and the " "grammar failed to convert to GBNF.", "outlines") - elif guided_params.json_object: - # https://github.com/mlc-ai/xgrammar/issues/256 - fallback_or_error(guided_params, - "xgrammar does not support json_object.", - "guidance") - # If the xgrammar module cannot be imported successfully, # we should still allow users to use guided decoding with a fallback. elif not xgr_installed: diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 47b1e7e3f9811..b44301f1a4c9b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -320,7 +320,10 @@ class XGrammarLogitsProcessor: elif self.config.grammar_str is not None: self.ctx = compiler.compile_grammar(self.config.grammar_str) elif self.config.json_object: - self.ctx = compiler.compile_builtin_json_grammar() + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema('{"type": "object"}', + any_whitespace=any_whitespace) else: raise ValueError( "Invalid configuration for xgrammar logits processor") diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 7fe62f26af597..783a33481243c 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -84,7 +84,9 @@ class XgrammarBackend(StructuredOutputBackend): ctx = self.compiler.compile_json_schema( grammar_spec, any_whitespace=not self.disable_any_whitespace) elif request_type == StructuredOutputOptions.JSON_OBJECT: - ctx = self.compiler.compile_builtin_json_grammar() + ctx = self.compiler.compile_json_schema( + '{"type": "object"}', + any_whitespace=not self.disable_any_whitespace) elif request_type == StructuredOutputOptions.GRAMMAR: ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: