mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-01 13:23:30 +08:00
[Bugfix] Fix CFGGuide and use outlines for grammars that can't convert to GBNF (#11389)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
e51719ae72
commit
5bfb30a529
@ -174,11 +174,6 @@ def test_guided_choice_completion(sample_guided_choice, llm,
|
|||||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
def test_guided_grammar(sample_sql_statements, llm,
|
def test_guided_grammar(sample_sql_statements, llm,
|
||||||
guided_decoding_backend: str):
|
guided_decoding_backend: str):
|
||||||
if guided_decoding_backend == "outlines":
|
|
||||||
pytest.skip("Outlines backend fails in this test case with:\n"
|
|
||||||
"AttributeError: Error in model execution: 'ParserConf' "
|
|
||||||
"object has no attribute 'deterministic'")
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
sampling_params = SamplingParams(temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
|
|||||||
@ -3,6 +3,9 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.guided_decoding.utils import (
|
||||||
|
convert_lark_to_gbnf, grammar_is_likely_lark,
|
||||||
|
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -15,76 +18,6 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
|
||||||
"""Check if JSON schema contains features unsupported by xgrammar."""
|
|
||||||
|
|
||||||
def check_object(obj: dict) -> bool:
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for pattern restrictions
|
|
||||||
if "pattern" in obj:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for numeric ranges
|
|
||||||
if obj.get("type") in ("integer", "number") and any(
|
|
||||||
key in obj for key in [
|
|
||||||
"minimum", "maximum", "exclusiveMinimum",
|
|
||||||
"exclusiveMaximum", "multipleOf"
|
|
||||||
]):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Recursively check all nested objects and arrays
|
|
||||||
for value in obj.values():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
if check_object(value):
|
|
||||||
return True
|
|
||||||
elif isinstance(value, list):
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict) and check_object(item):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
return check_object(schema)
|
|
||||||
|
|
||||||
|
|
||||||
def has_lmf_unsupported_json_features(schema: dict) -> bool:
|
|
||||||
"""
|
|
||||||
Check if JSON schema contains features unsupported
|
|
||||||
by lm_format_enforcer.
|
|
||||||
|
|
||||||
Known issues:
|
|
||||||
- Regex patterns:
|
|
||||||
"grade": {
|
|
||||||
"type": "string",
|
|
||||||
"pattern": "^[A-D]$" # Regex pattern
|
|
||||||
},
|
|
||||||
"""
|
|
||||||
|
|
||||||
def check_object(obj: dict) -> bool:
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for pattern restrictions
|
|
||||||
if "pattern" in obj:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Recursively check all nested objects and arrays
|
|
||||||
for value in obj.values():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
if check_object(value):
|
|
||||||
return True
|
|
||||||
elif isinstance(value, list):
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict) and check_object(item):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
return check_object(schema)
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_backend_fallback(
|
def maybe_backend_fallback(
|
||||||
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
||||||
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
||||||
@ -127,6 +60,20 @@ def maybe_backend_fallback(
|
|||||||
"Falling back to use outlines instead.")
|
"Falling back to use outlines instead.")
|
||||||
guided_params.backend = "outlines"
|
guided_params.backend = "outlines"
|
||||||
|
|
||||||
|
# xgrammar only supports GBNF grammars, so we must convert Lark.
|
||||||
|
# We must check if the grammar is likely Lark and if that
|
||||||
|
# grammar is convertible to GBNF
|
||||||
|
elif (guided_params.grammar is not None
|
||||||
|
and grammar_is_likely_lark(guided_params.grammar)):
|
||||||
|
try:
|
||||||
|
convert_lark_to_gbnf(guided_params.grammar)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"xgrammar does not support Lark grammars and the "
|
||||||
|
"grammar failed to convert to GBNF. "
|
||||||
|
"Falling back to use outlines instead.")
|
||||||
|
guided_params.backend = "outlines"
|
||||||
|
|
||||||
if (guided_params.backend == "outlines"
|
if (guided_params.backend == "outlines"
|
||||||
and guided_params.json_object is not None):
|
and guided_params.json_object is not None):
|
||||||
# outlines doesn't support json_object, fallback to xgrammar
|
# outlines doesn't support json_object, fallback to xgrammar
|
||||||
|
|||||||
@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from lark import Lark
|
|
||||||
from outlines import grammars
|
from outlines import grammars
|
||||||
from outlines.caching import cache
|
from outlines.caching import cache
|
||||||
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
|
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
|
||||||
|
RegexGuide, Write)
|
||||||
|
from outlines.fsm.parsing import PartialLark
|
||||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@ -34,7 +35,9 @@ class BaseLogitsProcessor:
|
|||||||
|
|
||||||
def __init__(self, guide: Guide):
|
def __init__(self, guide: Guide):
|
||||||
self._guide: Guide = guide
|
self._guide: Guide = guide
|
||||||
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
|
# CFGState is used for the FSM state for CFGGuide
|
||||||
|
self._fsm_state: DefaultDict[int, Union[int,
|
||||||
|
CFGState]] = defaultdict(int)
|
||||||
|
|
||||||
def __call__(self, input_ids: List[int],
|
def __call__(self, input_ids: List[int],
|
||||||
scores: torch.Tensor) -> torch.Tensor:
|
scores: torch.Tensor) -> torch.Tensor:
|
||||||
@ -54,15 +57,13 @@ class BaseLogitsProcessor:
|
|||||||
# On the first time this is called, we simply re-create
|
# On the first time this is called, we simply re-create
|
||||||
# the Lark object.
|
# the Lark object.
|
||||||
if isinstance(self._guide, CFGGuide):
|
if isinstance(self._guide, CFGGuide):
|
||||||
self._guide.parser = Lark(
|
self._guide.parser = PartialLark(
|
||||||
self._guide.cfg_string,
|
self._guide.cfg_string,
|
||||||
parser="lalr",
|
parser="lalr",
|
||||||
lexer="contextual",
|
|
||||||
propagate_positions=False,
|
|
||||||
maybe_placeholders=False,
|
|
||||||
regex=True,
|
|
||||||
import_paths=[grammars.GRAMMAR_PATH],
|
import_paths=[grammars.GRAMMAR_PATH],
|
||||||
)
|
)
|
||||||
|
self._fsm_state[seq_id] = CFGState(
|
||||||
|
parser_state=self._guide.parser.parse(""), prev_token=None)
|
||||||
|
|
||||||
instruction = self._guide.get_next_instruction(
|
instruction = self._guide.get_next_instruction(
|
||||||
state=self._fsm_state[seq_id])
|
state=self._fsm_state[seq_id])
|
||||||
@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
|||||||
string = tokenizer.convert_tokens_to_string([token])
|
string = tokenizer.convert_tokens_to_string([token])
|
||||||
|
|
||||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
|
||||||
|
or token == "<0x20>"):
|
||||||
return " " + string
|
return " " + string
|
||||||
|
|
||||||
return string
|
return string
|
||||||
@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
|||||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||||
|
|
||||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||||
|
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
||||||
|
and isinstance(inp_tokens[0], list)):
|
||||||
|
inp_tokens = inp_tokens[0]
|
||||||
return [decoder(inp_tokens)]
|
return [decoder(inp_tokens)]
|
||||||
|
|
||||||
return new_decoder
|
return new_decoder
|
||||||
|
|||||||
@ -1,6 +1,76 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||||
|
"""Check if JSON schema contains features unsupported by xgrammar."""
|
||||||
|
|
||||||
|
def check_object(obj: dict) -> bool:
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for pattern restrictions
|
||||||
|
if "pattern" in obj:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for numeric ranges
|
||||||
|
if obj.get("type") in ("integer", "number") and any(
|
||||||
|
key in obj for key in [
|
||||||
|
"minimum", "maximum", "exclusiveMinimum",
|
||||||
|
"exclusiveMaximum", "multipleOf"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Recursively check all nested objects and arrays
|
||||||
|
for value in obj.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if check_object(value):
|
||||||
|
return True
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict) and check_object(item):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return check_object(schema)
|
||||||
|
|
||||||
|
|
||||||
|
def has_lmf_unsupported_json_features(schema: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Check if JSON schema contains features unsupported
|
||||||
|
by lm_format_enforcer.
|
||||||
|
|
||||||
|
Known issues:
|
||||||
|
- Regex patterns:
|
||||||
|
"grade": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[A-D]$" # Regex pattern
|
||||||
|
},
|
||||||
|
"""
|
||||||
|
|
||||||
|
def check_object(obj: dict) -> bool:
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for pattern restrictions
|
||||||
|
if "pattern" in obj:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Recursively check all nested objects and arrays
|
||||||
|
for value in obj.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if check_object(value):
|
||||||
|
return True
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict) and check_object(item):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return check_object(schema)
|
||||||
|
|
||||||
|
|
||||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if grammar appears to use Lark syntax.
|
Check if grammar appears to use Lark syntax.
|
||||||
@ -14,8 +14,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||||
convert_lark_to_gbnf, grammar_is_likely_lark)
|
grammar_is_likely_lark)
|
||||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user