[Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (#15949)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent: bdb3660312
commit: 3cde34a4a4
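For reference, the new option rides on the engine-level backend string. A minimal offline sketch distilled from the tests below (the model, lengths, prompt, and schema are illustrative, not prescriptive):

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend=(
                  'guidance:no-additional-properties,disable-any-whitespace'))
    schema = {
        'type': 'object',
        'properties': {'a1': {'type': 'string'}},
        'required': ['a1'],
    }
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=256,
        guided_decoding=GuidedDecodingParams(json=schema))
    outputs = llm.generate(prompts="Emit a JSON object.",
                           sampling_params=sampling_params)
    # With no-additional-properties, keys beyond 'a1' are excluded.
    print(outputs[0].outputs[0].text)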
@@ -383,4 +383,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
         assert generated_text is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(llm):
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    base_generated = generate_with_backend('guidance:disable-any-whitespace')
+    assert "a1" in base_generated
+    assert "a2" in base_generated
+    assert "a3" in base_generated
+    # by default additional keys are generated
+    assert "a4" in base_generated
+    assert "a5" in base_generated
+    assert "a6" in base_generated
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated
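Aside: the backend string packs a backend name plus comma-separated options behind a colon. A tiny standalone sketch of the format this test exercises (the real parsing lives in GuidedDecodingParams.backend_options()):

    backend = 'guidance:no-additional-properties,disable-any-whitespace'
    name, _, opts = backend.partition(':')
    assert name == 'guidance'
    assert opts.split(',') == ['no-additional-properties',
                               'disable-any-whitespace']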
@@ -412,3 +412,59 @@ def test_structured_output_auto_mode(
     # Parse to verify it is valid JSON
     parsed_json = json.loads(generated_text)
     assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    backend = 'guidance:no-additional-properties,disable-any-whitespace'
+    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
+              max_model_len=1024,
+              guided_decoding_backend=backend)
+
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated
@@ -3107,7 +3107,7 @@ def get_served_model_name(model: str,
 
 
 GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
-                                  "xgrammar"]
+                                  "xgrammar", "guidance"]
 GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
 
 
@@ -18,7 +18,8 @@ import vllm.envs as envs
 from vllm import version
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
-                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
+                         DeviceConfig, DistributedExecutorBackend,
+                         GuidedDecodingBackendV1, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
@@ -1370,14 +1371,13 @@ class EngineArgs:
                                recommend_to_remove=True)
             return False
 
-        # Xgrammar and Guidance are supported.
-        SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
-            _raise_or_fallback(feature_name="--guided-decoding-backend",
-                               recommend_to_remove=False)
+        # remove backend options when doing this check
+        if self.guided_decoding_backend.split(':')[0] \
+                not in get_args(GuidedDecodingBackendV1):
+            _raise_or_fallback(
+                feature_name=
+                f"--guided-decoding-backend={self.guided_decoding_backend}",
+                recommend_to_remove=False)
             return False
 
         # Need at least Ampere for now (FA support required).
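The replacement check strips backend options before validating the backend name against the V1 literal, so any 'name:options' spelling of a supported backend now passes instead of requiring an entry in a hard-coded list. A self-contained sketch of the same logic:

    from typing import Literal, get_args

    GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

    def v1_supports(backend: str) -> bool:
        # Drop any ':opt1,opt2' suffix so options don't affect the check.
        return backend.split(':')[0] in get_args(GuidedDecodingBackendV1)

    assert v1_supports('guidance:no-additional-properties')
    assert not v1_supports('lm-format-enforcer')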
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 from re import escape as regex_escape
 
 import llguidance
@@ -7,6 +8,8 @@ from transformers import PreTrainedTokenizerBase
 from vllm.model_executor.guided_decoding.guidance_logits_processors import (
     GuidanceLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
+from vllm.v1.structured_output.backend_guidance import (
+    process_for_additional_properties)
 
 
 def get_local_guidance_guided_decoding_logits_processor(
@@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor(
     grm = ""
     any_whitespace = 'disable-any-whitespace' not in \
         guided_params.backend_options()
-    if guided_params.json:
+    if (guide_json := guided_params.json) is not None:
+        # Optionally set additionalProperties to False at the top-level
+        # By default, other backends do not allow additional top-level
+        # properties, so this makes guidance more similar to other backends
+        if 'no-additional-properties' in guided_params.backend_options():
+            if not isinstance(guide_json, str):
+                guide_json = json.dumps(guide_json)
+            guide_json = process_for_additional_properties(guide_json)
+
         grm = llguidance.LLMatcher.grammar_from_json_schema(
-            guided_params.json,
+            guide_json,
             overrides={"whitespace_pattern": guided_params.whitespace_pattern},
             defaults={
                 "whitespace_flexible": any_whitespace,
@@ -145,15 +145,7 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-
         engine_level_backend = self.decoding_config.guided_decoding_backend
-        if engine_level_backend not in supported_backends:
-            raise ValueError(f"Only {supported_backends} structured output is "
-                             "supported in V1.")
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
             # The values may differ if `params` is reused and was set
@@ -1,14 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -29,6 +31,29 @@ else:
 logger = init_logger(__name__)
 
 
+def _walk_json_for_additional_properties(data: object):
+    if isinstance(data, dict):
+        for value in data.values():
+            _walk_json_for_additional_properties(value)
+        if 'additionalProperties' not in data and \
+                ('properties' in data or 'patternProperties' in data):
+            data['additionalProperties'] = False
+    elif isinstance(data, list):
+        for item in data:
+            _walk_json_for_additional_properties(item)
+
+
+def process_for_additional_properties(
+        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
+    if isinstance(guide_json, str):
+        guide_json_obj = json.loads(guide_json)
+    else:
+        # copy for modifications
+        guide_json_obj = copy.deepcopy(guide_json)
+    _walk_json_for_additional_properties(guide_json_obj)
+    return guide_json_obj
+
+
 class GuidanceBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
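A quick demonstration of the walk (assumes a vLLM checkout containing this commit; the schema is illustrative). Only dicts that declare 'properties' or 'patternProperties' and do not already set 'additionalProperties' are touched, at every nesting level:

    import json

    from vllm.v1.structured_output.backend_guidance import (
        process_for_additional_properties)

    schema = {
        'type': 'object',
        'properties': {
            'inner': {
                'type': 'object',
                'properties': {'a1': {'type': 'string'}},
            },
        },
    }
    out = process_for_additional_properties(json.dumps(schema))
    # Both the outer object and the nested one now forbid extra keys.
    assert out['additionalProperties'] is False
    assert out['properties']['inner']['additionalProperties'] is False
    # The caller's schema is untouched: strings are parsed fresh and
    # dicts are deep-copied before modification.
    assert 'additionalProperties' not in schema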
@@ -41,9 +66,20 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
+
+        self.disable_any_whitespace = False
+        self.no_additional_properties = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            elif option == "no-additional-properties":
+                self.no_additional_properties = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the guidance backend: {option}")
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
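Design note: instead of substring-matching the raw backend string (the old behavior), the backend now instantiates GuidedDecodingParams purely to reuse its option parser, and rejects anything it does not recognize. A sketch of the assumed parsing behavior:

    from vllm.sampling_params import GuidedDecodingParams

    opts = GuidedDecodingParams(
        backend='guidance:no-additional-properties,disable-any-whitespace'
    ).backend_options()
    assert opts == ['no-additional-properties', 'disable-any-whitespace']
    # A typo like 'no-additional-props' would now raise ValueError in
    # GuidanceBackend.__init__ instead of being silently ignored.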
@@ -52,7 +88,8 @@ class GuidanceBackend(StructuredOutputBackend):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec, self.disable_any_whitespace)
+            request_type, grammar_spec, self.disable_any_whitespace,
+            self.no_additional_properties)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -129,10 +166,15 @@ class GuidanceGrammar(StructuredOutputGrammar):
         self.ll_matcher.reset()
 
 
-def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str,
-                               disable_any_whitespace: bool = False) -> str:
+def serialize_guidance_grammar(
+    request_type: StructuredOutputOptions,
+    grammar_spec: Union[str, dict[str, Any]],
+    disable_any_whitespace: bool = False,
+    no_additional_properties: bool = False,
+) -> str:
     if request_type == StructuredOutputOptions.JSON:
+        if no_additional_properties:
+            grammar_spec = process_for_additional_properties(grammar_spec)
         return llguidance.LLMatcher.grammar_from_json_schema(
             grammar_spec,
             defaults={
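A hedged sketch of driving the extended serializer directly (module paths assumed from the imports above; requires llguidance):

    from vllm.v1.structured_output.backend_guidance import (
        serialize_guidance_grammar)
    from vllm.v1.structured_output.backend_types import (
        StructuredOutputOptions)

    grammar = serialize_guidance_grammar(
        StructuredOutputOptions.JSON,
        {'type': 'object', 'properties': {'a1': {'type': 'string'}}},
        disable_any_whitespace=True,
        no_additional_properties=True,
    )
    # 'grammar' is the llguidance grammar string later compiled by
    # LLMatcher in GuidanceBackend.compile_grammar().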
@@ -9,7 +9,7 @@ import torch
 import vllm.envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
@@ -32,9 +32,6 @@ class XgrammarBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
         tokenizer_group = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
@@ -42,6 +39,17 @@ class XgrammarBackend(StructuredOutputBackend):
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
         tokenizer_group.ping()
 
+        self.disable_any_whitespace = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the xgrammar backend: {option}")
+
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
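Note the asymmetry with guidance: xgrammar accepts only disable-any-whitespace, since (per the comment added in guidance_decoding.py above) other backends already disallow additional top-level properties by default, which is exactly what the new guidance option emulates. A standalone sketch of the fail-fast behavior the loop above introduces:

    def check_xgrammar_options(backend: str) -> bool:
        # Mirrors the validation loop above, outside of vLLM (sketch only).
        disable_any_whitespace = False
        _, _, opts = backend.partition(':')
        for option in opts.split(',') if opts else []:
            if option == 'disable-any-whitespace':
                disable_any_whitespace = True
            else:
                raise ValueError(
                    f'Unsupported option for the xgrammar backend: {option}')
        return disable_any_whitespace

    assert check_xgrammar_options('xgrammar:disable-any-whitespace')
    try:
        check_xgrammar_options('xgrammar:no-additional-properties')
    except ValueError as exc:
        print(exc)  # fails loudly instead of being silently ignored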