[Core] choice-based structured output with xgrammar (#12632)

This commit is contained in:
Russell Bryant 2025-02-14 07:36:05 -05:00 committed by GitHub
parent 6224a9f620
commit 7734e9a291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 7 deletions

View File

@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11
lark == 1.2.2
xgrammar >= 0.1.6; platform_machine == "x86_64"
xgrammar >= 0.1.11; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs

View File

@ -49,11 +49,10 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
# xgrammar doesn't support regex or choice, fallback to outlines
if guided_params.regex is not None or guided_params.choice is not None:
logger.warning(
"xgrammar only supports json or grammar guided decoding. "
"Falling back to use outlines instead.")
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
logger.warning("xgrammar does not support regex guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
# xgrammar doesn't support some JSON schema features

View File

@ -5,8 +5,9 @@ from __future__ import annotations
import copy
import json
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, List
import torch
from transformers import PreTrainedTokenizerFast
@ -228,11 +229,39 @@ class GrammarConfig:
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.choice:
choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
try:
xgr.Grammar.from_ebnf(choice_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(
grammar_str=choice_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
)
@staticmethod
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in a EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)
@staticmethod
def choice_as_grammar(choice: List[str] | None) -> str:
if choice is None:
raise ValueError("Choice is not set")
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
@dataclass
class XGrammarLogitsProcessor: