[Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905)

Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Rémi Delacourt 2025-11-10 20:53:32 +01:00 committed by GitHub
parent 34553b9d27
commit 6d54336ae5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 126 additions and 2 deletions

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import StructuredOutputOptions
# Hugging Face tokenizer id used by every test in this module.
TOKENIZER = "gpt2"
def test_backend_guidance_rollback_terminated():
    """The guidance backend must be able to roll back out of a terminated
    state.

    This can happen with speculative decoding: the draft model proposes
    EOS and the guidance matcher verifies it, putting the grammar into a
    stopped (terminated) state. If the target model then rejects EOS, the
    rollback must also revert the terminated flag.
    """
    vllm_config = VllmConfig(
        # Use the same kwarg as test_grammar_bitmask_with_specdec
        # (`structured_outputs_config`) so both tests in this file
        # configure the backend consistently.
        structured_outputs_config=StructuredOutputsConfig(
            backend="guidance",
        )
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    backend = GuidanceBackend(
        vllm_config,
        tokenizer=tokenizer,
        vocab_size=50257,  # gpt2 vocabulary size
    )
    grammar = backend.compile_grammar(
        StructuredOutputOptions.JSON, '{"type": "object"}'
    )
    prompt = tokenizer.encode('{"a": "b"}')
    assert len(prompt) > 1
    dummy_wrong = tokenizer.encode('{"a"}')

    # Feed a complete JSON object token by token; the grammar stays live.
    for token in prompt:
        assert grammar.accept_tokens("", [token])
    assert not grammar.is_terminated()

    # Accepting EOS terminates the grammar.
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Once terminated, giving any other token should also be accepted.
    assert grammar.accept_tokens("", dummy_wrong)

    # Rollback is counted from where the state was terminated, so from '}'
    # (the last real token), not from EOS.
    grammar.rollback(len(prompt) - 1)
    assert not grammar.is_terminated()
    # Mid-object, EOS must not validate; the wrong continuation must not
    # validate in full either.
    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
    assert grammar.validate_tokens(dummy_wrong) != dummy_wrong

    # Re-accept the rest of the prompt and terminate again.
    assert grammar.accept_tokens("", prompt[1:])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()

    # Rollback of <= 0 should not change the terminated state.
    grammar.rollback(0)
    assert grammar.is_terminated()
    grammar.rollback(-1)
    assert grammar.is_terminated()
def test_grammar_bitmask_with_specdec():
    """Exercise StructuredOutputManager.grammar_bitmask with speculative
    tokens, checking the grammar is rolled back (not left terminated)
    after bitmask computation — including when EOS appears among the
    speculated tokens."""
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        # ngram drafter with 3 speculative tokens so grammar_bitmask is
        # exercised with multi-token speculative proposals.
        speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)
    for i in range(1, 2):
        sampling_params = SamplingParams(
            structured_outputs=StructuredOutputsParams(
                json='{"type": "object"}',
            ),
        )
        # Force the guidance backend regardless of the resolved default.
        # NOTE(review): reaches into the private `_backend` field — assumes
        # no public setter exists; confirm against StructuredOutputsParams.
        sampling_params.structured_outputs._backend = "guidance"
        my_req_id = f"my_req_id_{i}"
        request = Request(
            my_req_id,
            prompt_token_ids=prompt[:i],
            sampling_params=sampling_params,
            pooling_params=None,
            eos_token_id=tokenizer.eos_token_id,
        )
        structured_output_manager.grammar_init(request)

        def grammar_bitmask(req: Request, tokens: list[int]) -> None:
            # Compute the bitmask with `tokens` scheduled as this request's
            # speculative tokens, then check the manager restored the
            # grammar state afterwards.
            structured_output_manager.grammar_bitmask(
                requests={req.request_id: req},
                structured_output_request_ids={req.request_id: 0},
                scheduled_spec_decode_tokens={req.request_id: tokens},
            )
            # At this point, we rolled-back, so should not be terminated
            assert not req.structured_output_request.grammar.is_terminated()

        # The grammar might not yet be compiled, so we wait for it
        while not request.structured_output_request._check_grammar_completion():
            continue
        assert request.structured_output_request.grammar.accept_tokens(
            request.request_id, prompt[:i]
        )
        # EOS as the final speculated token.
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
        grammar_bitmask(
            request, prompt[i:] + [tokenizer.eos_token_id] + prompt
        )  # EOS not the final token
        grammar_bitmask(request, prompt[i:])  # EOS not present
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])

View File

@@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
vocab_size: int
printed_error: bool = False
terminated: bool = False
rollback_lag: int = 0
def check_error(self):
if not self.printed_error:
@@ -127,6 +128,8 @@ class GuidanceGrammar(StructuredOutputGrammar):
"""
if self.ll_tokenizer.eos_token in tokens:
if self.ll_matcher.is_stopped() and not self.terminated:
self.rollback_lag = 1
self.terminated = True
if self.ll_matcher.is_stopped():
@@ -163,8 +166,11 @@ class GuidanceGrammar(StructuredOutputGrammar):
return tokens[:num_tokens]
def rollback(self, num_tokens: int) -> None:
    """Rewind the matcher by ``num_tokens`` previously accepted tokens.

    The flattened diff left both the old body (an unconditional
    ``self.ll_matcher.rollback(num_tokens)``) and the new guarded body in
    place, which would roll the matcher back twice; only the new
    implementation is kept here.

    When the grammar terminated because EOS was accepted, ``rollback_lag``
    is 1: EOS was never fed to ``ll_matcher`` (it stopped on the token
    before), so we roll back one token fewer than the caller asked for.
    A non-positive ``num_tokens`` is a no-op and must not clear the
    terminated state.
    """
    if num_tokens > 0:
        self.ll_matcher.rollback(num_tokens - self.rollback_lag)
        # Leaving the terminated state: the speculated EOS was rejected.
        self.terminated = False
        self.rollback_lag = 0
    self.check_error()
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
# this will automatically return [EOS] mask if the matcher is stopped