[V1] Support disable_any_whitespace for guidance backend (#15584)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2025-03-28 11:46:45 -04:00 committed by GitHub
parent 541d1df486
commit 7329ff5468
6 changed files with 44 additions and 117 deletions
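This commit extends the disable-any-whitespace backend option, previously xgrammar-only, to the guidance backend: with a backend string of "guidance:disable-any-whitespace", llguidance JSON grammars are built with whitespace_flexible=False, so guided output cannot contain newlines between JSON tokens. A minimal sketch of the user-facing API, using the model and call signatures that appear in the tests below (the prompt text is illustrative):

    from vllm.entrypoints.llm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
    # The backend string carries the new option after the colon.
    guided = GuidedDecodingParams(json={"type": "object"},
                                  backend="guidance:disable-any-whitespace")
    outputs = llm.generate(
        prompts="Respond with a JSON object.",
        sampling_params=SamplingParams(max_tokens=100,
                                       guided_decoding=guided))
    # With the option set, the generated JSON contains no "\n".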

View File

@@ -6,7 +6,6 @@ import weakref
 import jsonschema
 import pytest
-from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS = [
-    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+    "outlines",
+    "lm-format-enforcer",
+    "xgrammar:disable-any-whitespace",
+    "guidance:disable-any-whitespace",
 ]
@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
         print(generated_text)
         assert generated_text is not None
+        if 'disable-any-whitespace' in guided_decoding_backend:
+            assert "\n" not in generated_text
 
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
-
-
-@pytest.mark.skip_global_cleanup
-def test_json_with_any_whitespace_disabled(llm):
-
-    class ResponseSchema(BaseModel):
-        clarifying_question: str
-        cost_per_serving: str
-        calories: str
-        type_dish_ids: str
-        type_meal_ids: str
-        product_ids: list[str]
-        exclude_product_ids: list[str]
-        allergen_ids: list[str]
-        total_cooking_time: str
-        kitchen_ids: str
-        holiday_ids: str
-
-    # Note: Without this setting, the response is sometimes full of `\n`
-    # for some models. This option prevents that.
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    schema = ResponseSchema.model_json_schema()
-    guided_params = GuidedDecodingParams(json=schema,
-                                         backend=\
-                                         guided_decoding_backend)
-    sampling_params = SamplingParams(max_tokens=2000,
-                                     frequency_penalty=0,
-                                     presence_penalty=-1.1,
-                                     repetition_penalty=1.3,
-                                     guided_decoding=guided_params)
-
-    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
-              "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
-              "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
-    outputs = llm.generate(prompts=prompt,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
-
-        # Parse to verify it is valid JSON
-        parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
-        jsonschema.validate(instance=parsed_json, schema=schema)

View File

@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
+GUIDED_DECODING_BACKENDS_V1 = [
+    "xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
+]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -55,50 +57,8 @@ def test_guided_json_completion(
         generated_text = output.outputs[0].text
         assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_disable_any_whitespace(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    if guided_decoding_backend != "xgrammar":
-        pytest.skip("disable-any-whitespace is only supported for xgrammar.")
-    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        assert "\n" not in generated_text
+        if 'disable-any-whitespace' in guided_decoding_backend:
+            assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -142,7 +102,7 @@ def test_guided_json_object(
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
 
         allowed_types: tuple[type, ...] = (dict, )
-        if guided_decoding_backend == "xgrammar":
+        if guided_decoding_backend.startswith("xgrammar"):
             # TODO - we are currently too permissive with xgrammar and
             # allow any valid json (typically comes back as a list or
             # object). We can fix this by specifying a jsonschema of
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
         temperature=1.0,
         max_tokens=1000,
         guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
-    if guided_decoding_backend == "xgrammar":
+    if guided_decoding_backend.startswith("xgrammar"):
         with pytest.raises(ValueError,
                            match="The provided JSON schema contains features "
                            "not supported by xgrammar."):

View File

@@ -1561,7 +1561,8 @@ class EngineArgs:
         # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
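The option can also be set engine-wide; any string outside SUPPORTED_GUIDED_DECODING goes through _raise_or_fallback. A sketch mirroring the LLM(...) construction used in the tests above:

    from vllm.entrypoints.llm import LLM

    # Engine-level default backend for all guided-decoding requests.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance:disable-any-whitespace")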

View File

@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
     """
     grm = ""
+    any_whitespace = 'disable-any-whitespace' not in \
+        guided_params.backend_options()
     if guided_params.json:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             guided_params.json,
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.json_object:
         grm = llguidance.LLMatcher.grammar_from_json_schema(
             '{"type": "object"}',
-            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern},
+            defaults={
+                "whitespace_flexible": any_whitespace,
+            })
     elif guided_params.regex:
         grm = llguidance.grammar_from("regex", guided_params.regex)
     elif guided_params.choice:
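whitespace_flexible is what the option ultimately controls: a flexible grammar accepts whitespace (including "\n") between JSON tokens, a non-flexible one forces compact single-line JSON. A standalone sketch using the same llguidance call as the code above (the schema literal is illustrative):

    import llguidance

    schema = '{"type": "object", "properties": {"a": {"type": "integer"}}}'

    # Flexible: whitespace, including newlines, may appear between tokens.
    flexible = llguidance.LLMatcher.grammar_from_json_schema(
        schema, defaults={"whitespace_flexible": True})

    # Compact: the guided output must be single-line JSON.
    compact = llguidance.LLMatcher.grammar_from_json_schema(
        schema, defaults={"whitespace_flexible": False})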

View File

@@ -121,7 +121,8 @@ class Processor:
             return
 
         supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
+            "guidance:disable-any-whitespace", "auto"
         ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
@@ -140,11 +141,10 @@ class Processor:
             raise ValueError("Structured output is not supported on TPU.")
 
         # Request content validation
-        if engine_level_backend == "xgrammar":
+        if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
             validate_structured_output_request_xgrammar(params)
-            params.guided_decoding.backend = "xgrammar"
+            params.guided_decoding.backend = engine_level_backend
         elif engine_level_backend == "auto":
             # "auto" is an opt-in to opinionated behavior where we try to
             # choose a backend based on request contents. This is not the
@@ -158,12 +158,13 @@ class Processor:
                 # are not supported in xgrammar. Fall back to guidance.
                 params.guided_decoding.backend = "guidance"
 
-        if params.guided_decoding.backend == "guidance":
+        if engine_level_backend.startswith("guidance"):
             # TODO ideally we would have the LLTokenizer here as Lark syntax
             # allows <|special_token|> and similar, see
             # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
             # Without tokenizer these are disallowed in grammars.
             validate_guidance_grammar(params, tokenizer=None)
+            params.guided_decoding.backend = engine_level_backend
 
     def process_inputs(
         self,
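The startswith() checks above let one base backend match with or without trailing options, and assigning engine_level_backend (rather than the bare name) keeps those options visible to the V1 structured-output backend. A hypothetical helper, not part of vLLM, illustrating the backend-string convention:

    # Hypothetical: shows the "backend:opt1,opt2" convention that
    # backend_options() and the startswith() checks rely on.
    def split_backend(backend: str) -> tuple[str, list[str]]:
        name, _, opts = backend.partition(":")
        return name, opts.split(",") if opts else []

    assert split_backend("guidance:disable-any-whitespace") == \
        ("guidance", ["disable-any-whitespace"])
    assert split_backend("xgrammar") == ("xgrammar", [])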

View File

@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.disable_any_whitespace = (
+            "disable-any-whitespace"
+            in vllm_config.decoding_config.guided_decoding_backend)
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec)
+            request_type, grammar_spec, self.disable_any_whitespace)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
 def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str) -> str:
+                               grammar_spec: str,
+                               disable_any_whitespace: bool = False) -> str:
     if request_type == StructuredOutputOptions.JSON:
-        # TODO: make whitespace_flexible configurable
         return llguidance.LLMatcher.grammar_from_json_schema(
-            grammar_spec, defaults={
-                "whitespace_flexible": True,
+            grammar_spec,
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
             })
     elif request_type == StructuredOutputOptions.JSON_OBJECT:
         return llguidance.LLMatcher.grammar_from_json_schema(
-            '{"type": "object"}', defaults={
-                "whitespace_flexible": True,
+            '{"type": "object"}',
+            defaults={
+                "whitespace_flexible": not disable_any_whitespace,
            })
     else:
         if request_type == StructuredOutputOptions.REGEX:
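With the new parameter, serialize_guidance_grammar can be asked for a compact grammar directly; a sketch of the updated signature (the import paths are assumptions, not shown in this diff):

    # Assumed module paths. StructuredOutputOptions.JSON takes the
    # JSON-schema path, and disable_any_whitespace=True yields a grammar
    # with whitespace_flexible=False.
    from vllm.v1.structured_output.backend_guidance import (
        serialize_guidance_grammar)
    from vllm.v1.structured_output.backend_types import StructuredOutputOptions

    grammar = serialize_guidance_grammar(StructuredOutputOptions.JSON,
                                         '{"type": "object"}',
                                         disable_any_whitespace=True)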