[Bugfix] Backend option to disable xgrammar any_whitespace (#12744)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Wallas Henrique authored on 2025-02-26 15:52:34 -03:00, committed by GitHub
parent d08b285adf
commit 4cb6fa0a9c
3 changed files with 88 additions and 3 deletions
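
For reference, a minimal usage sketch of the new option, passed per-request
through GuidedDecodingParams just as the new test below does. The model name
is an assumed placeholder and the schema is illustrative:

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
schema = {"type": "object",
          "properties": {"answer": {"type": "string"}},
          "required": ["answer"]}
params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(
        json=schema, backend="xgrammar:disable-any-whitespace"))
outputs = llm.generate("Reply with a JSON object.", params)
print(outputs[0].outputs[0].text)  # compact JSON, no free-form whitespace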

tests/entrypoints/llm/test_guided_generate.py

@@ -6,6 +6,7 @@ import weakref

 import jsonschema
 import pytest
+from pydantic import BaseModel

 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_json_with_any_whitespace_disabled(llm):
+
+    class ResponseSchema(BaseModel):
+        clarifying_question: str
+        cost_per_serving: str
+        calories: str
+        type_dish_ids: str
+        type_meal_ids: str
+        product_ids: list[str]
+        exclude_product_ids: list[str]
+        allergen_ids: list[str]
+        total_cooking_time: str
+        kitchen_ids: str
+        holiday_ids: str
+
+    # Note: Without this setting, the response is sometimes full of `\n`
+    # for some models. This option prevents that.
+    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
+
+    schema = ResponseSchema.model_json_schema()
+    guided_params = GuidedDecodingParams(json=schema,
+                                         backend=guided_decoding_backend)
+    sampling_params = SamplingParams(max_tokens=2000,
+                                     frequency_penalty=0,
+                                     presence_penalty=-1.1,
+                                     repetition_penalty=1.3,
+                                     guided_decoding=guided_params)
+
+    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. "
+              "You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
+              "I want a quick launch fast with $10.<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    outputs = llm.generate(prompts=prompt,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        assert "\n" not in generated_text
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
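
A note on the assert "\n" not in generated_text check above: with
any_whitespace disabled, the compiled grammar only admits JSON with no
whitespace between structural tokens (whitespace inside string values is
still allowed), roughly the shape json.dumps produces in compact mode. A
tiny stdlib-only illustration:

import json

# The compact form is the only serialization the strict grammar allows.
compact = json.dumps({"calories": "350", "kitchen_ids": "7"},
                     separators=(",", ":"))
print(compact)  # {"calories":"350","kitchen_ids":"7"}
assert "\n" not in compact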

vllm/engine/arg_utils.py

@@ -385,6 +385,7 @@ class EngineArgs:
            'Backend-specific options can be supplied in a comma-separated '
            'list following a colon after the backend name. Valid backends and '
            'all available options are: [xgrammar:no-fallback, '
+           'xgrammar:disable-any-whitespace, '
            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
        parser.add_argument(
            '--logits-processor-pattern',
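
The option can also be set engine-wide instead of per-request. A sketch,
assuming the LLM constructor forwards guided_decoding_backend to EngineArgs
(it mirrors the --guided-decoding-backend CLI flag documented above):

from vllm.entrypoints.llm import LLM

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          guided_decoding_backend="xgrammar:disable-any-whitespace")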

vllm/model_executor/guided_decoding/xgrammar_decoding.py

@@ -19,6 +19,7 @@ except ImportError:
     xgr_installed = False
     pass

+from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                        grammar_is_likely_lark)
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -29,6 +30,8 @@ if TYPE_CHECKING:
     from vllm.config import ModelConfig
     from vllm.sampling_params import GuidedDecodingParams

+logger = init_logger(__name__)
+

 # TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
@@ -161,6 +164,7 @@ class GrammarConfig:
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
+    any_whitespace: bool = True
     max_threads: int = 8
     tokenizer_data: TokenizerData | None = None

@@ -180,11 +184,33 @@ class GrammarConfig:
                 json_str = json.dumps(guided_params.json)
             else:
                 json_str = guided_params.json

+            any_whitespace = ('disable-any-whitespace'
+                              not in guided_params.backend_options())
+
+            # Check and warn if the model has a known history of runaway
+            # whitespace generation with xgrammar and any_whitespace enabled.
+            # References:
+            # https://github.com/vllm-project/vllm/pull/12744
+            # https://github.com/mlc-ai/xgrammar/issues/212
+            model_with_warn = None
+
+            if 'Mistral' in model_config.model:
+                model_with_warn = 'Mistral'
+            elif 'Qwen' in model_config.model:
+                model_with_warn = 'Qwen'
+
+            if model_with_warn is not None and any_whitespace:
+                msg = (f"{model_with_warn} model detected, consider setting "
+                       "`guided_backend=xgrammar:disable-any-whitespace` "
+                       "to prevent runaway generation of whitespace.")
+                logger.info_once(msg)
+
             # Validate the schema and raise ValueError here if it is invalid.
             # This is to avoid exceptions in model execution, which will crash
             # the engine worker process.
             try:
-                xgr.Grammar.from_json_schema(json_str)
+                xgr.Grammar.from_json_schema(json_str,
+                                             any_whitespace=any_whitespace)
             except RuntimeError as err:
                 raise ValueError(str(err)) from err
@@ -192,7 +218,8 @@ class GrammarConfig:
                 vocab_size=model_config.hf_text_config.vocab_size,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
-                tokenizer_data=tokenizer_data)
+                tokenizer_data=tokenizer_data,
+                any_whitespace=any_whitespace)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -290,7 +317,10 @@ class XGrammarLogitsProcessor:
         if self.ctx is None:
             compiler = GrammarCompilerCache.get_compiler(self.config)
             if self.config.json_str is not None:
-                self.ctx = compiler.compile_json_schema(self.config.json_str)
+                any_whitespace = self.config.any_whitespace
+                self.ctx = compiler.compile_json_schema(
+                    self.config.json_str,
+                    any_whitespace=any_whitespace)
             elif self.config.grammar_str is not None:
                 self.ctx = compiler.compile_grammar(self.config.grammar_str)
             elif self.config.json_object:
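
Finally, how the option string is consumed, inferred from the calls above:
GuidedDecodingParams.backend_options() yields the comma-separated tokens
after the colon in the backend string, and the xgrammar config maps the
presence of 'disable-any-whitespace' to any_whitespace=False. A quick sketch:

from vllm.sampling_params import GuidedDecodingParams

gp = GuidedDecodingParams(json={"type": "object"},
                          backend="xgrammar:disable-any-whitespace")
assert "disable-any-whitespace" in gp.backend_options()

# Same derivation as in GrammarConfig.from_guided_params above.
any_whitespace = 'disable-any-whitespace' not in gp.backend_options()
assert any_whitespace is False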