From 1f16b7fe746f840bec2d3df30da5e5e2d31ca2a8 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 20 Mar 2025 00:33:51 -0400
Subject: [PATCH] [Core][V0] Add guidance backend for structured output
 (#14589)

Signed-off-by: Russell Bryant
Co-authored-by: Loc Huynh
Co-authored-by: Michal Moskal
Co-authored-by: Aaron Pham
---
 .../benchmark_serving_structured_output.py     | 11 +--
 requirements/common.txt                        |  1 +
 tests/entrypoints/llm/test_guided_generate.py  |  4 +-
 .../model_executor/test_guided_processors.py   |  4 +-
 vllm/config.py                                 |  4 +-
 .../guided_decoding/__init__.py                | 27 ++++--
 .../guided_decoding/guidance_decoding.py       | 44 ++++++++++
 .../guidance_logits_processors.py              | 85 +++++++++++++++++++
 8 files changed, 167 insertions(+), 13 deletions(-)
 create mode 100644 vllm/model_executor/guided_decoding/guidance_decoding.py
 create mode 100644 vllm/model_executor/guided_decoding/guidance_logits_processors.py

diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py
index 444bda2ad26ba..c79a93faff197 100644
--- a/benchmarks/benchmark_serving_structured_output.py
+++ b/benchmarks/benchmark_serving_structured_output.py
@@ -999,11 +999,12 @@ if __name__ == "__main__":
                         type=float,
                         default=1.0,
                         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")
 
     args = parser.parse_args()
     main(args)
diff --git a/requirements/common.txt b/requirements/common.txt
index d08ef253828b1..2d52858ad9e1b 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -18,6 +18,7 @@ pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
+llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
index 97ee027bde3bf..5f1a91cb2b19f 100644
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index 85a53a178ca75..59da575e37b18 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
 from vllm.sampling_params import GuidedDecodingParams
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
"guidance" +] GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" diff --git a/vllm/config.py b/vllm/config.py index 2d8f1ba483e12..ffff3b7c8a8ef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2785,7 +2785,9 @@ class DecodingConfig: return hash_str def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] + valid_guided_backends = [ + 'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance' + ] backend = GuidedDecodingParams( backend=self.guided_decoding_backend).backend_name diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index c21df044d48f6..0c26a60588c88 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -79,6 +79,12 @@ def maybe_backend_fallback( "xgrammar does not support Lark grammars and the " "grammar failed to convert to GBNF.", "outlines") + elif guided_params.json_object: + # https://github.com/mlc-ai/xgrammar/issues/256 + fallback_or_error(guided_params, + "xgrammar does not support json_object.", + "guidance") + # If the xgrammar module cannot be imported successfully, # we should still allow users to use guided decoding with a fallback. elif not xgr_installed: @@ -88,9 +94,9 @@ def maybe_backend_fallback( if (guided_params.backend_name == "outlines" and guided_params.json_object is not None): - # outlines doesn't support json_object, fallback to xgrammar + # outlines doesn't support json_object, fallback to guidance fallback_or_error(guided_params, - "outlines does not support json_object.", "xgrammar") + "outlines does not support json_object.", "guidance") return guided_params @@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor( get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( guided_params, tokenizer, model_config, reasoner) - + if guided_params.backend_name == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) def get_local_guided_decoding_logits_processor( @@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor( get_local_xgrammar_guided_decoding_logits_processor) return get_local_xgrammar_guided_decoding_logits_processor( guided_params, tokenizer, model_config, reasoner) + if guided_params.backend_name == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. 
" - "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py new file mode 100644 index 0000000000000..d8675a14030de --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +from re import escape as regex_escape + +import llguidance +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.guidance_logits_processors import ( + GuidanceLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams + + +def get_local_guidance_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + """ + + grm = "" + if guided_params.json: + grm = llguidance.LLMatcher.grammar_from_json_schema( + guided_params.json, + overrides={"whitespace_pattern": guided_params.whitespace_pattern}) + elif guided_params.json_object: + grm = llguidance.LLMatcher.grammar_from_json_schema( + '{"type": "object"}', + overrides={"whitespace_pattern": guided_params.whitespace_pattern}) + elif guided_params.regex: + grm = llguidance.grammar_from("regex", guided_params.regex) + elif guided_params.choice: + # choice just uses regex + choices = (regex_escape(str(choice)) + for choice in guided_params.choice) + choices_regex = "(" + "|".join(choices) + ")" + grm = llguidance.grammar_from("regex", choices_regex) + elif guided_params.grammar: + # this supports Lark and GBNF + grm = llguidance.grammar_from("grammar", guided_params.grammar) + + if grm: + return GuidanceLogitsProcessor(grm, tokenizer) + + raise ValueError("Unknown guided decoding mode") diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py new file mode 100644 index 0000000000000..26fcafe31c765 --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Any, List + +import llguidance +import llguidance.hf +import llguidance.torch +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class GuidanceLogitsProcessor: + """Base Guidance Logits Processor""" + + cached_tokenizers: dict[str, Any] = {} + + def __init__( + self, + grammar: str, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + """Base Guidance Logits Processor + + Args: + grammar (str) + grammar to guide the generation + tokenizer (PreTrainedTokenizerBase) + model's tokenizer + """ + self.grammar = grammar + self.tokenizer = tokenizer + self.tokenizer_name = tokenizer.name_or_path + self.new_sampling = False + self.initialized = False + + def _initialize(self): + if self.initialized: + return + + ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path, + None) + if ll_tokenizer is None: + ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) + self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer + + self.ll_tokenizer = ll_tokenizer + self.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, + self.grammar, + 
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + + # create reusable bitmask + self.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) + + self.initialized = True + + def __call__( + self, + input_ids: List[int], + scores: torch.Tensor, + ) -> torch.Tensor: + # we initialize the guidance model here + # to avoid pickling ll_tokenizer and ll_interpreter + self._initialize() + + if self.new_sampling and len(input_ids) > 0: + self.ll_matcher.consume_token(input_ids[-1]) + err = self.ll_matcher.get_error() + if err: + logger.warning("Error in LLMatcher: %s", err) + + llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, + 0) + llguidance.torch.apply_token_bitmask_inplace( + scores, self.bitmask.to(scores.device)) + + self.new_sampling = True + + return scores
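Usage sketch (not part of the commit): selecting the new "guidance" backend
end to end through the LLM entrypoint, mirroring
tests/entrypoints/llm/test_guided_generate.py. The model name, prompt, and
JSON schema below are illustrative assumptions, not values from the patch.

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    # Select the backend engine-wide; per-request selection via
    # GuidedDecodingParams(backend="guidance") should also work.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              guided_decoding_backend="guidance")

    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }
    sampling_params = SamplingParams(
        max_tokens=128,
        guided_decoding=GuidedDecodingParams(json=schema))

    outputs = llm.generate("Describe a person as JSON.", sampling_params)
    print(outputs[0].outputs[0].text)  # output is constrained to the schema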
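For the grammar-construction path in guidance_decoding.py above, a standalone
llguidance sketch (again not part of the commit; the tokenizer name is
illustrative) of how a JSON-schema grammar is compiled and attached to a
matcher:

    import llguidance
    import llguidance.hf
    from transformers import AutoTokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    ll_tokenizer = llguidance.hf.from_tokenizer(hf_tokenizer, None)

    # Same call the patch uses for the json_object case
    grm = llguidance.LLMatcher.grammar_from_json_schema(
        '{"type": "object"}',
        overrides={"whitespace_pattern": None})
    matcher = llguidance.LLMatcher(ll_tokenizer, grm)
    # matcher.consume_token(...) advances the grammar state;
    # llguidance.torch.fill_next_token_bitmask(...) then exposes
    # which tokens are legal next, as in the processor above.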
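And a per-step sketch of GuidanceLogitsProcessor itself (not part of the
commit): a single call masks the logits of every token the grammar disallows.
The choice values and dummy logits are illustrative.

    import torch
    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding.guidance_decoding import (
        get_local_guidance_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    params = GuidedDecodingParams(choice=["yes", "no"])  # compiled to a regex
    processor = get_local_guidance_guided_decoding_logits_processor(
        params, tokenizer)

    processor._initialize()  # normally lazy; forced here to read vocab_size
    scores = torch.zeros(processor.ll_tokenizer.vocab_size)  # dummy logits
    masked = processor([], scores)  # first step: no tokens consumed yet
    assert torch.isinf(masked).any()  # disallowed tokens pushed to -inf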