[Bugfix] Fix the issue where llm.generate cannot be called repeatedly after setting GuidedDecodingParams (#16767)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Chauncey 2025-04-22 14:02:20 +08:00 committed by GitHub
parent a114bf20a3
commit acba33a0f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 4 deletions

View File

@ -386,13 +386,21 @@ def test_structured_output_auto_mode(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
prompts = ("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}")
# This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically.
outputs = llm.generate(prompts=("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}"),
outputs = llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True)
# Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error.
outputs.extend(
llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True))
assert outputs is not None
for output in outputs:
assert output is not None

View File

@ -79,6 +79,17 @@ class GuidedDecodingParams:
return []
return self.backend.split(":")[1].split(",")
def add_option(self, opt_name: str) -> None:
"""Adds an option to the backend options."""
if not self.backend:
self.backend = f":{opt_name}"
elif ":" not in self.backend:
self.backend += f":{opt_name}"
else:
options = set(self.backend_options())
options.add(opt_name)
self.backend = f"{self.backend_name}:{','.join(sorted(options))}"
def no_fallback(self) -> bool:
"""Returns True if the "no-fallback" option is supplied for the guided
decoding backend"""

View File

@ -155,7 +155,14 @@ class Processor:
raise ValueError(f"Only {supported_backends} structured output is "
"supported in V1.")
if params.guided_decoding.backend:
if params.guided_decoding.backend != engine_level_backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto" and "_auto"
in params.guided_decoding.backend_options())):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
@ -190,6 +197,8 @@ class Processor:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.add_option("_auto")
def process_inputs(
self,