[Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (#15949)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Travis Johnson 2025-04-23 12:34:41 -06:00 committed by GitHub
parent bdb3660312
commit 3cde34a4a4
8 changed files with 201 additions and 34 deletions
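In short, this commit adds a no-additional-properties option to the guidance structured-output backend so that, as with xgrammar, JSON object schemas reject keys they do not declare. Options are appended to the backend name after a colon, comma-separated. A minimal usage sketch assembled from the tests in this commit (the model name is the one the tests use; the prompt and schema are illustrative):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Engine-level backend string; options follow the backend name after ':'.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend=(
              "guidance:no-additional-properties,disable-any-whitespace"))

schema = {
    "type": "object",
    "properties": {"a1": {"type": "string"}},
    "required": ["a1"],
}
params = SamplingParams(temperature=0,
                        max_tokens=256,
                        guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate("Reply with a JSON object.", sampling_params=params)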

View File

@@ -383,4 +383,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
         assert generated_text is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(llm):
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    base_generated = generate_with_backend('guidance:disable-any-whitespace')
+    assert "a1" in base_generated
+    assert "a2" in base_generated
+    assert "a3" in base_generated
+    # by default additional keys are generated
+    assert "a4" in base_generated
+    assert "a5" in base_generated
+    assert "a6" in base_generated
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated

View File

@@ -412,3 +412,59 @@ def test_structured_output_auto_mode(
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    backend = 'guidance:no-additional-properties,disable-any-whitespace'
+    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
+              max_model_len=1024,
+              guided_decoding_backend=backend)
+
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated

View File

@@ -3107,7 +3107,7 @@ def get_served_model_name(model: str,
 GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
-                                  "xgrammar"]
+                                  "xgrammar", "guidance"]
 GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

View File

@@ -18,7 +18,8 @@ import vllm.envs as envs
 from vllm import version
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
-                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
+                         DeviceConfig, DistributedExecutorBackend,
+                         GuidedDecodingBackendV1, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
@@ -1370,14 +1371,13 @@ class EngineArgs:
                                recommend_to_remove=True)
             return False
 
-        # Xgrammar and Guidance are supported.
-        SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
-            _raise_or_fallback(feature_name="--guided-decoding-backend",
-                               recommend_to_remove=False)
+        # remove backend options when doing this check
+        if self.guided_decoding_backend.split(':')[0] \
+                not in get_args(GuidedDecodingBackendV1):
+            _raise_or_fallback(
+                feature_name=
+                f"--guided-decoding-backend={self.guided_decoding_backend}",
+                recommend_to_remove=False)
             return False
 
         # Need at least Ampere for now (FA support required).
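With this change, only the backend name before the colon is validated against the GuidedDecodingBackendV1 literal; option strings are left for each backend to accept or reject on its own. A standalone sketch of the check (the literal is copied from the config change above):

from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

backend = "guidance:no-additional-properties,disable-any-whitespace"
# Only the prefix before ':' is checked here.
assert backend.split(':')[0] in get_args(GuidedDecodingBackendV1)
# An unsupported backend name still fails, with or without options.
assert "lm-format-enforcer" not in get_args(GuidedDecodingBackendV1)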

View File

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 from re import escape as regex_escape
 
 import llguidance
@@ -7,6 +8,8 @@ from transformers import PreTrainedTokenizerBase
 from vllm.model_executor.guided_decoding.guidance_logits_processors import (
     GuidanceLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
+from vllm.v1.structured_output.backend_guidance import (
+    process_for_additional_properties)
 
 
 def get_local_guidance_guided_decoding_logits_processor(
@@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor(
     grm = ""
     any_whitespace = 'disable-any-whitespace' not in \
         guided_params.backend_options()
-    if guided_params.json:
+    if (guide_json := guided_params.json) is not None:
+        # Optionally set additionalProperties to False at the top-level
+        # By default, other backends do not allow additional top-level
+        # properties, so this makes guidance more similar to other backends
+        if 'no-additional-properties' in guided_params.backend_options():
+            if not isinstance(guide_json, str):
+                guide_json = json.dumps(guide_json)
+            guide_json = process_for_additional_properties(guide_json)
         grm = llguidance.LLMatcher.grammar_from_json_schema(
-            guided_params.json,
+            guide_json,
             overrides={"whitespace_pattern": guided_params.whitespace_pattern},
             defaults={
                 "whitespace_flexible": any_whitespace,

View File

@@ -145,15 +145,7 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
-        if engine_level_backend not in supported_backends:
-            raise ValueError(f"Only {supported_backends} structured output is "
-                             "supported in V1.")
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
             # The values may differ if `params` is reused and was set

View File

@@ -1,14 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
+import copy
+import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -29,6 +31,29 @@ else:
 
 logger = init_logger(__name__)
 
 
+def _walk_json_for_additional_properties(data: object):
+    if isinstance(data, dict):
+        for value in data.values():
+            _walk_json_for_additional_properties(value)
+        if 'additionalProperties' not in data and \
+                ('properties' in data or 'patternProperties' in data):
+            data['additionalProperties'] = False
+    elif isinstance(data, list):
+        for item in data:
+            _walk_json_for_additional_properties(item)
+
+
+def process_for_additional_properties(
+        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
+    if isinstance(guide_json, str):
+        guide_json_obj = json.loads(guide_json)
+    else:
+        # copy for modifications
+        guide_json_obj = copy.deepcopy(guide_json)
+    _walk_json_for_additional_properties(guide_json_obj)
+    return guide_json_obj
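process_for_additional_properties walks the whole schema, so nested objects are closed off too, and the caller's dict is never mutated. A worked example (the schema is made up for illustration):

from vllm.v1.structured_output.backend_guidance import (
    process_for_additional_properties)

schema = {
    "type": "object",
    "properties": {
        "inner": {
            "type": "object",
            "properties": {"x": {"type": "string"}},
        },
    },
}
processed = process_for_additional_properties(schema)
# Every object that declares 'properties' (or 'patternProperties') but not
# 'additionalProperties' gets 'additionalProperties': False.
assert processed["additionalProperties"] is False
assert processed["properties"]["inner"]["additionalProperties"] is False
# The input schema is deep-copied, not modified in place.
assert "additionalProperties" not in schema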
 class GuidanceBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
@@ -41,9 +66,20 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
+        self.disable_any_whitespace = False
+        self.no_additional_properties = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            elif option == "no-additional-properties":
+                self.no_additional_properties = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the guidance backend: {option}")
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
@@ -52,7 +88,8 @@ class GuidanceBackend(StructuredOutputBackend):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec, self.disable_any_whitespace)
+            request_type, grammar_spec, self.disable_any_whitespace,
+            self.no_additional_properties)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -129,10 +166,15 @@ class GuidanceGrammar(StructuredOutputGrammar):
         self.ll_matcher.reset()
 
 
-def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str,
-                               disable_any_whitespace: bool = False) -> str:
+def serialize_guidance_grammar(
+    request_type: StructuredOutputOptions,
+    grammar_spec: Union[str, dict[str, Any]],
+    disable_any_whitespace: bool = False,
+    no_additional_properties: bool = False,
+) -> str:
     if request_type == StructuredOutputOptions.JSON:
+        if no_additional_properties:
+            grammar_spec = process_for_additional_properties(grammar_spec)
         return llguidance.LLMatcher.grammar_from_json_schema(
             grammar_spec,
             defaults={
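serialize_guidance_grammar is also callable directly; for JSON requests it applies the rewrite before building the llguidance grammar. A hedged sketch of a direct call (imports and enum value assumed from this file's own usage; the schema is illustrative):

from vllm.v1.structured_output.backend_guidance import (
    serialize_guidance_grammar)
from vllm.v1.structured_output.backend_types import StructuredOutputOptions

# Returns a serialized llguidance grammar string with the schema rewritten
# so undeclared properties are rejected.
grammar = serialize_guidance_grammar(
    StructuredOutputOptions.JSON,
    {"type": "object", "properties": {"a1": {"type": "string"}}},
    disable_any_whitespace=True,
    no_additional_properties=True,
)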

View File

@@ -9,7 +9,7 @@ import torch
 import vllm.envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
@@ -32,9 +32,6 @@ class XgrammarBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
         tokenizer_group = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
@@ -42,6 +39,17 @@ class XgrammarBackend(StructuredOutputBackend):
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
         tokenizer_group.ping()
 
+        self.disable_any_whitespace = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the xgrammar backend: {option}")
+
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
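xgrammar gets the same option-parsing loop but accepts only disable-any-whitespace; passing no-additional-properties to it raises, since xgrammar already rejects undeclared properties by default (which is what the new guidance option emulates). A small sketch of the expected failure mode, at the option-string level:

from vllm.sampling_params import GuidedDecodingParams

# The option parses fine at the string level...
opts = GuidedDecodingParams(
    backend="xgrammar:no-additional-properties").backend_options()
assert opts == ["no-additional-properties"]
# ...but XgrammarBackend.__init__ raises
# ValueError("Unsupported option for the xgrammar backend: ...") for it.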