From 4cb6fa0a9c7ceca688f6fa86a29ccd40abd3936d Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Wed, 26 Feb 2025 15:52:34 -0300 Subject: [PATCH] [Bugfix] Backend option to disable xgrammar any_whitespace (#12744) Signed-off-by: Wallas Santos Signed-off-by: Joe Runde Co-authored-by: Joe Runde --- tests/entrypoints/llm/test_guided_generate.py | 54 +++++++++++++++++++ vllm/engine/arg_utils.py | 1 + .../guided_decoding/xgrammar_decoding.py | 36 +++++++++++-- 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 314dc59328cb..fce581c78288 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -6,6 +6,7 @@ import weakref import jsonschema import pytest +from pydantic import BaseModel from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM @@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str): # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +def test_json_with_any_whitespace_disabled(llm): + + class ResponseSchema(BaseModel): + clarifying_question: str + cost_per_serving: str + calories: str + type_dish_ids: str + type_meal_ids: str + product_ids: list[str] + exclude_product_ids: list[str] + allergen_ids: list[str] + total_cooking_time: str + kitchen_ids: str + holiday_ids: str + + # Note: Without this setting, the response is sometimes full of `\n` + # for some models. This option prevents that. 
+    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
+
+    schema = ResponseSchema.model_json_schema()
+    guided_params = GuidedDecodingParams(json=schema,
+                                         backend=\
+                                         guided_decoding_backend)
+    sampling_params = SamplingParams(max_tokens=2000,
+                                     frequency_penalty=0,
+                                     presence_penalty=-1.1,
+                                     repetition_penalty=1.3,
+                                     guided_decoding=guided_params)
+
+    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You "
+              "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
+              "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
+    outputs = llm.generate(prompts=prompt,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        assert "\n" not in generated_text
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 663ea1ef8afd..26d4a84b841c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -385,6 +385,7 @@ class EngineArgs:
             'Backend-specific options can be supplied in a comma-separated '
             'list following a colon after the backend name. 
Valid backends and ' 'all available options are: [xgrammar:no-fallback, ' + 'xgrammar:disable-any-whitespace, ' 'outlines:no-fallback, lm-format-enforcer:no-fallback]') parser.add_argument( '--logits-processor-pattern', diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index e6ba7f5ecc6e..eb9d83acb286 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -19,6 +19,7 @@ except ImportError: xgr_installed = False pass +from vllm.logger import init_logger from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, grammar_is_likely_lark) from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer @@ -29,6 +30,8 @@ if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sampling_params import GuidedDecodingParams +logger = init_logger(__name__) + # TODO: passing batch size to max threads here def get_local_xgrammar_guided_decoding_logits_processor( @@ -161,6 +164,7 @@ class GrammarConfig: json_str: str | None = None grammar_str: str | None = None json_object: bool | None = None + any_whitespace: bool = True max_threads: int = 8 tokenizer_data: TokenizerData | None = None @@ -180,11 +184,33 @@ class GrammarConfig: else: json_str = guided_params.json + any_whitespace = 'disable-any-whitespace' not in \ + guided_params.backend_options() + + # Check and log if model with xgrammar and whitespace have history + # of runaway generation of whitespaces. 
+            # References:
+            # https://github.com/vllm-project/vllm/pull/12744
+            # https://github.com/mlc-ai/xgrammar/issues/212
+            model_with_warn = None
+
+            if 'Mistral' in model_config.model:
+                model_with_warn = 'Mistral'
+            elif 'Qwen' in model_config.model:
+                model_with_warn = 'Qwen'
+
+            if model_with_warn is not None and any_whitespace:
+                msg = (f"{model_with_warn} "
+                       f"model detected, consider setting "
+                       f"`guided_backend=xgrammar:disable-any-whitespace` "
+                       f"to prevent runaway generation of whitespaces.")
+                logger.info_once(msg)
             # Validate the schema and raise ValueError here if it is invalid.
             # This is to avoid exceptions in model execution, which will crash
             # the engine worker process.
             try:
-                xgr.Grammar.from_json_schema(json_str)
+                xgr.Grammar.from_json_schema(json_str,
+                                             any_whitespace=any_whitespace)
             except RuntimeError as err:
                 raise ValueError(str(err)) from err
@@ -192,7 +218,8 @@
                        vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
-                       tokenizer_data=tokenizer_data)
+                       tokenizer_data=tokenizer_data,
+                       any_whitespace=any_whitespace)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -290,7 +317,10 @@ class XGrammarLogitsProcessor:
         if self.ctx is None:
             compiler = GrammarCompilerCache.get_compiler(self.config)
             if self.config.json_str is not None:
-                self.ctx = compiler.compile_json_schema(self.config.json_str)
+                any_whitespace = self.config.any_whitespace
+                self.ctx = compiler\
+                    .compile_json_schema(self.config.json_str,
+                                         any_whitespace=any_whitespace)
             elif self.config.grammar_str is not None:
                 self.ctx = compiler.compile_grammar(self.config.grammar_str)
             elif self.config.json_object: