From 3cde34a4a4bceb511f2f9fe2dade1da116eb7e8a Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 23 Apr 2025 12:34:41 -0600 Subject: [PATCH] [Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (#15949) Signed-off-by: Travis Johnson --- tests/entrypoints/llm/test_guided_generate.py | 60 ++++++++++++++++++- .../llm/test_struct_output_generate.py | 56 +++++++++++++++++ vllm/config.py | 2 +- vllm/engine/arg_utils.py | 18 +++--- .../guided_decoding/guidance_decoding.py | 15 ++++- vllm/v1/engine/processor.py | 8 --- vllm/v1/structured_output/backend_guidance.py | 60 ++++++++++++++++--- vllm/v1/structured_output/backend_xgrammar.py | 16 +++-- 8 files changed, 201 insertions(+), 34 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index e43e9826e8f9b..6bc32ebadc7a3 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -383,4 +383,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, schema=json_schema) \ No newline at end of file + jsonschema.validate(instance=output_json, schema=json_schema) + + +@pytest.mark.skip_global_cleanup +def test_guidance_no_additional_properties(llm): + schema = { + 'type': 'object', + 'properties': { + 'a1': { + 'type': 'string' + }, + 'a2': { + 'type': 'string' + }, + 'a3': { + 'type': 'string' + } + }, + 'required': ['a1', 'a2', 'a3'], + } + + prompt = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " + "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "<|im_end|>\n<|im_start|>assistant\n") + + def generate_with_backend(backend): + guided_params = GuidedDecodingParams(json=schema, backend=backend) + sampling_params = SamplingParams(temperature=0, + max_tokens=256, + guided_decoding=guided_params) + + outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + assert outputs is not None + generated_text = outputs[0].outputs[0].text + assert generated_text is not None + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) + return parsed_json + + base_generated = generate_with_backend('guidance:disable-any-whitespace') + assert "a1" in base_generated + assert "a2" in base_generated + assert "a3" in base_generated + # by default additional keys are generated + assert "a4" in base_generated + assert "a5" in base_generated + assert "a6" in base_generated + + generated = generate_with_backend( + 'guidance:no-additional-properties,disable-any-whitespace') + assert "a1" in generated + assert "a2" in generated + assert "a3" in generated + assert "a4" not in generated + assert "a5" not in generated + assert "a6" not in generated diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index c243d81e7f183..fc8e271f7f915 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -412,3 +412,59 @@ def test_structured_output_auto_mode( # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + 
+ +@pytest.mark.skip_global_cleanup +def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_V1", "1") + + backend = 'guidance:no-additional-properties,disable-any-whitespace' + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", + max_model_len=1024, + guided_decoding_backend=backend) + + schema = { + 'type': 'object', + 'properties': { + 'a1': { + 'type': 'string' + }, + 'a2': { + 'type': 'string' + }, + 'a3': { + 'type': 'string' + } + }, + 'required': ['a1', 'a2', 'a3'], + } + + prompt = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " + "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "<|im_end|>\n<|im_start|>assistant\n") + + def generate_with_backend(backend): + guided_params = GuidedDecodingParams(json=schema, backend=backend) + sampling_params = SamplingParams(temperature=0, + max_tokens=256, + guided_decoding=guided_params) + + outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + assert outputs is not None + generated_text = outputs[0].outputs[0].text + assert generated_text is not None + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) + return parsed_json + + generated = generate_with_backend( + 'guidance:no-additional-properties,disable-any-whitespace') + assert "a1" in generated + assert "a2" in generated + assert "a3" in generated + assert "a4" not in generated + assert "a5" not in generated + assert "a6" not in generated diff --git a/vllm/config.py b/vllm/config.py index 709983d05a509..741ce04d5dffd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3107,7 +3107,7 @@ def get_served_model_name(model: str, GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", - "xgrammar"] + "xgrammar", "guidance"] GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b6d0bfeac4a44..c67759c1d09a4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -18,7 +18,8 @@ import vllm.envs as envs from vllm import version from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, DecodingConfig, Device, - DeviceConfig, DistributedExecutorBackend, HfOverrides, + DeviceConfig, DistributedExecutorBackend, + GuidedDecodingBackendV1, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, @@ -1370,14 +1371,13 @@ class EngineArgs: recommend_to_remove=True) return False - # Xgrammar and Guidance are supported. - SUPPORTED_GUIDED_DECODING = [ - "xgrammar", "xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] - if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: - _raise_or_fallback(feature_name="--guided-decoding-backend", - recommend_to_remove=False) + # remove backend options when doing this check + if self.guided_decoding_backend.split(':')[0] \ + not in get_args(GuidedDecodingBackendV1): + _raise_or_fallback( + feature_name= + f"--guided-decoding-backend={self.guided_decoding_backend}", + recommend_to_remove=False) return False # Need at least Ampere for now (FA support required). 
diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py index f19ebcbe420e3..95b7c71107aab 100644 --- a/vllm/model_executor/guided_decoding/guidance_decoding.py +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json from re import escape as regex_escape import llguidance @@ -7,6 +8,8 @@ from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.guidance_logits_processors import ( GuidanceLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams +from vllm.v1.structured_output.backend_guidance import ( + process_for_additional_properties) def get_local_guidance_guided_decoding_logits_processor( @@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor( grm = "" any_whitespace = 'disable-any-whitespace' not in \ guided_params.backend_options() - if guided_params.json: + if (guide_json := guided_params.json) is not None: + # Optionally set additionalProperties to False at the top-level + # By default, other backends do not allow additional top-level + # properties, so this makes guidance more similar to other backends + if 'no-additional-properties' in guided_params.backend_options(): + if not isinstance(guide_json, str): + guide_json = json.dumps(guide_json) + guide_json = process_for_additional_properties(guide_json) + grm = llguidance.LLMatcher.grammar_from_json_schema( - guided_params.json, + guide_json, overrides={"whitespace_pattern": guided_params.whitespace_pattern}, defaults={ "whitespace_flexible": any_whitespace, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 26c57b31aacd7..7e6b7ba470350 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -145,15 +145,7 @@ class Processor: if not params.guided_decoding or not self.decoding_config: return - supported_backends = [ - "xgrammar", "xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] - engine_level_backend = self.decoding_config.guided_decoding_backend - if engine_level_backend not in supported_backends: - raise ValueError(f"Only {supported_backends} structured output is " - "supported in V1.") if params.guided_decoding.backend: # Request-level backend selection is not supported in V1. 
# The values may differ if `params` is reused and was set diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 9150a28570bdd..0edb15558dce2 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 +import copy +import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -29,6 +31,29 @@ else: logger = init_logger(__name__) +def _walk_json_for_additional_properties(data: object): + if isinstance(data, dict): + for value in data.values(): + _walk_json_for_additional_properties(value) + if 'additionalProperties' not in data and \ + ('properties' in data or 'patternProperties' in data): + data['additionalProperties'] = False + elif isinstance(data, list): + for item in data: + _walk_json_for_additional_properties(item) + + +def process_for_additional_properties( + guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + if isinstance(guide_json, str): + guide_json_obj = json.loads(guide_json) + else: + # copy for modifications + guide_json_obj = copy.deepcopy(guide_json) + _walk_json_for_additional_properties(guide_json_obj) + return guide_json_obj + + class GuidanceBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): @@ -41,9 +66,20 @@ class GuidanceBackend(StructuredOutputBackend): tokenizer_group.ping() self.vllm_config = vllm_config self.vocab_size = vllm_config.model_config.get_vocab_size() - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) + + self.disable_any_whitespace = False + self.no_additional_properties = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + elif option == "no-additional-properties": + self.no_additional_properties = True + else: + raise ValueError( + f"Unsupported option for the guidance backend: {option}") tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( @@ -52,7 +88,8 @@ class GuidanceBackend(StructuredOutputBackend): def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, grammar_spec, self.disable_any_whitespace) + request_type, grammar_spec, self.disable_any_whitespace, + self.no_additional_properties) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -129,10 +166,15 @@ class GuidanceGrammar(StructuredOutputGrammar): self.ll_matcher.reset() -def serialize_guidance_grammar(request_type: StructuredOutputOptions, - grammar_spec: str, - disable_any_whitespace: bool = False) -> str: +def serialize_guidance_grammar( + request_type: StructuredOutputOptions, + grammar_spec: Union[str, dict[str, Any]], + disable_any_whitespace: bool = False, + 
no_additional_properties: bool = False, +) -> str: if request_type == StructuredOutputOptions.JSON: + if no_additional_properties: + grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index c9839bd7ddee0..1e4470153e306 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -9,7 +9,7 @@ import torch import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader @@ -32,9 +32,6 @@ class XgrammarBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, @@ -42,6 +39,17 @@ class XgrammarBackend(StructuredOutputBackend): lora_config=vllm_config.lora_config) # type: ignore[arg-type] tokenizer_group.ping() + self.disable_any_whitespace = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + else: + raise ValueError( + f"Unsupported option for the xgrammar backend: {option}") + tokenizer = tokenizer_group.get_lora_tokenizer(None) self.vocab_size = vllm_config.model_config.get_vocab_size() if isinstance(tokenizer, MistralTokenizer):
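
For illustration, a minimal sketch of the schema rewrite performed when the new
no-additional-properties option is enabled (assumes a build with this patch
applied; the example schema below is arbitrary and not part of the patch):

    # Sketch only: process_for_additional_properties is the helper added by
    # this patch in vllm/v1/structured_output/backend_guidance.py.
    import json

    from vllm.v1.structured_output.backend_guidance import (
        process_for_additional_properties)

    # Arbitrary nested schema for demonstration.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "address": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
        "required": ["name"],
    }

    processed = process_for_additional_properties(schema)
    # Every object schema that declares "properties" or "patternProperties"
    # and does not already set "additionalProperties" now carries
    # "additionalProperties": false, nested objects included.
    print(json.dumps(processed, indent=2))

The option itself is selected through the backend string, for example
--guided-decoding-backend guidance:no-additional-properties,disable-any-whitespace
at the engine level, or per request as in the tests above with
GuidedDecodingParams(json=schema, backend="guidance:no-additional-properties").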