From 6d54336ae550528408e0c84cffb7857c426509f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com>
Date: Mon, 10 Nov 2025 20:53:32 +0100
Subject: [PATCH] [Bugfix] Fix llguidance backend, rollback when EOS was
 encountered (#25905)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Rémi Delacourt
Signed-off-by: remi
Co-authored-by: Russell Bryant
---
 .../test_backend_guidance.py                  | 118 ++++++++++++++++++
 vllm/v1/structured_output/backend_guidance.py |  10 +-
 2 files changed, 126 insertions(+), 2 deletions(-)
 create mode 100644 tests/v1/structured_output/test_backend_guidance.py

diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py
new file mode 100644
index 0000000000000..771076186a3b4
--- /dev/null
+++ b/tests/v1/structured_output/test_backend_guidance.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from transformers import AutoTokenizer
+
+from vllm.config import StructuredOutputsConfig, VllmConfig
+from vllm.config.model import ModelConfig
+from vllm.config.speculative import SpeculativeConfig
+from vllm.sampling_params import SamplingParams, StructuredOutputsParams
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
+from vllm.v1.structured_output.backend_types import StructuredOutputOptions
+
+TOKENIZER = "gpt2"
+
+
+def test_backend_guidance_rollback_terminated():
+    # Test that the guidance backend successfully rolls back from a
+    # terminated state. This can happen with speculative decoding,
+    # where the draft model proposes EOS and it is verified by the
+    # guidance backend. In that case we are in a stopped state, but
+    # it should be reverted in case EOS is not accepted by the target
+    # model.
+    vllm_config = VllmConfig(
+        structured_outputs_config=StructuredOutputsConfig(
+            backend="guidance",
+        )
+    )
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
+
+    backend = GuidanceBackend(
+        vllm_config,
+        tokenizer=tokenizer,
+        vocab_size=50257,
+    )
+
+    grammar = backend.compile_grammar(
+        StructuredOutputOptions.JSON, '{"type": "object"}'
+    )
+
+    prompt = tokenizer.encode('{"a": "b"}')
+    assert len(prompt) > 1
+    dummy_wrong = tokenizer.encode('{"a"}')
+    for token in prompt:
+        assert grammar.accept_tokens("", [token])
+    assert not grammar.is_terminated()
+    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
+    assert grammar.is_terminated()
+    # Giving any other token should also be accepted
+    assert grammar.accept_tokens("", dummy_wrong)
+    # Rollback is done from where the state was terminated, so from '}' not EOS
+    grammar.rollback(len(prompt) - 1)
+    assert not grammar.is_terminated()
+    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
+    assert grammar.validate_tokens(dummy_wrong) != dummy_wrong
+    assert grammar.accept_tokens("", prompt[1:])
+    assert not grammar.is_terminated()
+    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
+    assert grammar.is_terminated()
+    # Rollback of <= 0 should not change the terminated state
+    grammar.rollback(0)
+    assert grammar.is_terminated()
+    grammar.rollback(-1)
+    assert grammar.is_terminated()
+
+
+def test_grammar_bitmask_with_specdec():
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
+    prompt = tokenizer.encode('{"a": "b"}')
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(tokenizer=TOKENIZER),
+        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
+        speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
+    )
+    structured_output_manager = StructuredOutputManager(vllm_config)
+
+    for i in range(1, 2):
+        sampling_params = SamplingParams(
+            structured_outputs=StructuredOutputsParams(
+                json='{"type": "object"}',
+            ),
+        )
+        sampling_params.structured_outputs._backend = "guidance"
+
+        my_req_id = f"my_req_id_{i}"
+        request = Request(
+            my_req_id,
+            prompt_token_ids=prompt[:i],
+            sampling_params=sampling_params,
+            pooling_params=None,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+        structured_output_manager.grammar_init(request)
+
+        def grammar_bitmask(req: Request, tokens: list[int]) -> None:
+            structured_output_manager.grammar_bitmask(
+                requests={req.request_id: req},
+                structured_output_request_ids={req.request_id: 0},
+                scheduled_spec_decode_tokens={req.request_id: tokens},
+            )
+            # At this point we have rolled back, so we should not be terminated
+            assert not req.structured_output_request.grammar.is_terminated()
+
+        # The grammar might not yet be compiled, so we wait for it
+        while not request.structured_output_request._check_grammar_completion():
+            continue
+
+        assert request.structured_output_request.grammar.accept_tokens(
+            request.request_id, prompt[:i]
+        )
+
+        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
+        grammar_bitmask(
+            request, prompt[i:] + [tokenizer.eos_token_id] + prompt
+        )  # EOS not the final token
+        grammar_bitmask(request, prompt[i:])  # EOS not present
+        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index 00a625e103bd3..2962a439dcb3e 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
     vocab_size: int
     printed_error: bool = False
     terminated: bool = False
+    rollback_lag: int = 0
 
     def check_error(self):
         if not self.printed_error:
@@ -127,6 +128,8 @@ class GuidanceGrammar(StructuredOutputGrammar):
         """
 
         if self.ll_tokenizer.eos_token in tokens:
+            if self.ll_matcher.is_stopped() and not self.terminated:
+                self.rollback_lag = 1
             self.terminated = True
 
         if self.ll_matcher.is_stopped():
@@ -163,8 +166,11 @@ class GuidanceGrammar(StructuredOutputGrammar):
         return tokens[:num_tokens]
 
     def rollback(self, num_tokens: int) -> None:
-        self.ll_matcher.rollback(num_tokens)
-        self.check_error()
+        if num_tokens > 0:
+            self.ll_matcher.rollback(num_tokens - self.rollback_lag)
+            self.terminated = False
+            self.rollback_lag = 0
+            self.check_error()
 
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         # this will automatically return [EOS] mask if the matcher is stopped
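
For review context, a condensed sketch of the speculative-decoding scenario the new
rollback_lag field covers. It is stitched together from the test above and uses only
calls the patch itself exercises (gpt2 tokenizer, the JSON object grammar,
accept_tokens / rollback / validate_tokens); the draft/target framing in the comments
is illustrative narration, not vLLM scheduler code.

    from transformers import AutoTokenizer

    from vllm.config import StructuredOutputsConfig, VllmConfig
    from vllm.v1.structured_output.backend_guidance import GuidanceBackend
    from vllm.v1.structured_output.backend_types import StructuredOutputOptions

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    backend = GuidanceBackend(
        VllmConfig(structured_outputs_config=StructuredOutputsConfig(backend="guidance")),
        tokenizer=tokenizer,
        vocab_size=50257,
    )
    grammar = backend.compile_grammar(StructuredOutputOptions.JSON, '{"type": "object"}')

    # The draft model speculates a complete JSON object, then EOS. Once the
    # object is closed the underlying ll_matcher is stopped, so the EOS call
    # returns early without consuming a token; only `terminated` flips.
    speculated = tokenizer.encode('{"a": "b"}')
    assert grammar.accept_tokens("", speculated)
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()

    # The target model rejects the tail of the speculation, so the scheduler
    # rolls back. Before this patch, rollback() passed num_tokens straight to
    # ll_matcher.rollback() and never reset `terminated`, so the matcher was
    # rewound one real token too many and the grammar stayed terminated.
    grammar.rollback(len(speculated) - 1)
    assert not grammar.is_terminated()
    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []  # object reopened

The lag is never more than 1: accept_tokens returns as soon as the matcher is stopped,
so EOS is the only token that can flip `terminated` without advancing ll_matcher, and
rollback() subtracts exactly that phantom token before rewinding.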