Mirror of https://git.datalinker.icu/vllm-project/vllm.git, last synced 2026-04-10 04:27:02 +08:00.
[Core] Support Lark grammars for XGrammar (#10870)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
a1887f2c96
commit
8b59631855
@ -73,14 +73,6 @@ def maybe_backend_fallback(
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
# xgrammar only supports EBNF grammars and uses the GBNF format
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
elif (guided_params.grammar is not None
|
||||
and "::=" not in guided_params.grammar):
|
||||
logger.warning("xgrammar only supports EBNF grammars. "
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
# xgrammar doesn't support some JSON schema features
|
||||
elif (guided_params.json is not None
|
||||
and has_xgrammar_unsupported_json_features(guided_params.json)):
|
||||
|
||||
@ -14,6 +14,9 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
||||
convert_lark_to_gbnf, grammar_is_likely_lark)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@ -152,7 +155,19 @@ class GrammarConfig:
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads)
|
||||
elif guided_params.grammar:
|
||||
return cls(grammar_str=guided_params.grammar,
|
||||
# XGrammar only supports GBNF grammars, so we must convert Lark
|
||||
if grammar_is_likely_lark(guided_params.grammar):
|
||||
try:
|
||||
grammar_str = convert_lark_to_gbnf(guided_params.grammar)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Failed to convert the grammar from Lark to GBNF. "
|
||||
"Please either use GBNF grammar directly or specify"
|
||||
" --guided-decoding-backend=outlines.\n"
|
||||
f"Conversion error: {str(e)}") from e
|
||||
else:
|
||||
grammar_str = guided_params.grammar
|
||||
return cls(grammar_str=grammar_str,
|
||||
vocab_size=model_config.hf_config.vocab_size,
|
||||
encoded_vocab=encoded_vocab,
|
||||
stop_token_ids=stop_token_ids,
|
||||
|
||||
New file (162 lines): vllm/model_executor/guided_decoding/xgrammar_utils.py
@ -0,0 +1,162 @@
|
||||
import re
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    # None / empty string / non-string input can never be a Lark grammar.
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    for line in grammar_str.split('\n'):
        # Remove both comment styles (Lark uses //, shell-style # also seen)
        line = re.sub(r'(#|//).*$', '', line).strip()
        if not line:
            continue

        # '::=' is the GBNF rule-definition operator, so its presence marks
        # the grammar as GBNF. This must be decided BEFORE the feature probes
        # below: '|' is a legal alternative separator in GBNF too, so a line
        # like `root ::= "a" | "b"` must not be misclassified as Lark.
        # (Bug fix: the original only excluded '::=' from the ':' check.)
        if '::=' in line:
            return False

        # Lark-style rule definitions use a single ':' separator.
        if ':' in line:
            return True

        # Lark-specific features: optional-start marker, alternatives on
        # continuation lines, and the '~' repetition operator.
        if any(pattern in line for pattern in ['?start:', '|', '~']):
            return True

    return False
|
||||
|
||||
|
||||
def convert_lark_to_gbnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to GBNF format.

    GBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in GBNF format

    Raises:
        ValueError: If the input is not a string, is empty, contains
            mismatched quotes, defines no rules, or references rules that
            are never defined.

    Examples:
        >>> print(convert_lark_to_gbnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    # Rule names seen on the left-hand side vs. names referenced on the
    # right-hand side; compared at the end to catch undefined references.
    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        # NOTE(review): counts raw quote characters, so an escaped quote
        # inside a literal would be miscounted — confirm grammars never
        # contain escaped quotes before relying on this check.
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters. Callers convert
        # '...' literals to "..." before calling, so stripping double-quoted
        # strings removes all terminal literals; operators/brackets are then
        # blanked so only bare identifiers (rule references) remain.
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions.
    # The first rule defined becomes the GBNF root, unless a rule literally
    # named 'start' exists anywhere — then 'start' wins.
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        # Continuation lines ('| alt') are not definitions; skip here.
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                # strip('?') drops Lark's '?rule' inline-rule marker so the
                # GBNF name is the plain identifier.
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule (GBNF requires an explicit 'root').
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives.
    # Alternatives spread over '|' continuation lines are accumulated in
    # current_definition and joined when the next rule begins (or at EOF).
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # Convert Lark single-quoted literals to GBNF double quotes.
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            # Re-wrap with the line number so users can locate the problem.
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists (the loop only flushes on the NEXT rule).
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined ('root' is synthesized above, so it is
    # excluded from the check).
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)
|
||||
Loading…
x
Reference in New Issue
Block a user