Add backward compatibility for GuidedDecodingParams (#25422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-09-23 17:07:30 +01:00 committed by GitHub
parent cc1dc7ed6d
commit 875d6def90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 1 deletions

View File

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from dataclasses import fields
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -21,7 +22,8 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
StructuredOutputsParams)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import TokenizerMode from vllm.config import TokenizerMode
@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
return json.loads(s) return json.loads(s)
def test_guided_decoding_deprecated():
    """Check that the deprecated guided-decoding API still works and warns.

    Covers three things:
    * constructing ``GuidedDecodingParams`` emits a ``DeprecationWarning``
      and exposes the same dataclass fields as ``StructuredOutputsParams``;
    * passing ``guided_decoding=`` to ``SamplingParams(...)`` and to
      ``SamplingParams.from_optional(...)`` each emit a
      ``DeprecationWarning``;
    * both construction paths produce equal ``SamplingParams`` whose
      ``structured_outputs`` compares equal to the legacy object.
    """
    class_warning = "GuidedDecodingParams is deprecated.*"
    kwarg_warning = "guided_decoding is deprecated.*"

    with pytest.warns(DeprecationWarning, match=class_warning):
        legacy_params = GuidedDecodingParams(json_object=True)

    replacement_params = StructuredOutputsParams(json_object=True)
    assert fields(legacy_params) == fields(replacement_params)

    # Direct construction and the from_optional factory must warn
    # independently, so each gets its own pytest.warns context.
    with pytest.warns(DeprecationWarning, match=kwarg_warning):
        via_init = SamplingParams(guided_decoding=legacy_params)

    with pytest.warns(DeprecationWarning, match=kwarg_warning):
        via_factory = SamplingParams.from_optional(
            guided_decoding=legacy_params)

    assert via_init == via_factory
    assert via_init.structured_outputs == legacy_params
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config", "model_name, backend, tokenizer_mode, speculative_config",

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
import warnings
from dataclasses import field from dataclasses import field
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
@ -59,6 +60,19 @@ class StructuredOutputsParams:
f"but multiple are specified: {self.__dict__}") f"but multiple are specified: {self.__dict__}")
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias of ``StructuredOutputsParams``.

    Declares no fields of its own; it only adds a ``DeprecationWarning``
    on construction and then defers to the parent ``__post_init__``.
    """

    def __post_init__(self):
        """Emit the deprecation warning, then run the parent's hook."""
        # __post_init__ is invoked from the dataclass __init__, so the stack
        # is: this method -> __init__ -> user code. stacklevel=3 attributes
        # the warning to the user's call site; stacklevel=2 would point at
        # __init__, which Python's default DeprecationWarning filter
        # (shown only when triggered from __main__) would hide.
        # NOTE(review): assumes exactly one Python-level __init__ frame sits
        # between this hook and the caller - confirm for the dataclass
        # decorator in use.
        warnings.warn(
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.",
            DeprecationWarning,
            stacklevel=3)
        return super().__post_init__()
class RequestOutputKind(Enum): class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput # Return entire output so far in every RequestOutput
CUMULATIVE = 0 CUMULATIVE = 0
@ -179,6 +193,8 @@ class SamplingParams(
# Fields used to construct logits processors # Fields used to construct logits processors
structured_outputs: Optional[StructuredOutputsParams] = None structured_outputs: Optional[StructuredOutputsParams] = None
"""Parameters for configuring structured outputs.""" """Parameters for configuring structured outputs."""
guided_decoding: Optional[GuidedDecodingParams] = None
"""Deprecated alias for structured_outputs."""
logit_bias: Optional[dict[int, float]] = None logit_bias: Optional[dict[int, float]] = None
"""If provided, the engine will construct a logits processor that applies """If provided, the engine will construct a logits processor that applies
these logit biases.""" these logit biases."""
@ -227,6 +243,7 @@ class SamplingParams(
ge=-1)]] = None, ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: Optional[StructuredOutputsParams] = None, structured_outputs: Optional[StructuredOutputsParams] = None,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None, allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None, extra_args: Optional[dict[str, Any]] = None,
@ -238,6 +255,15 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias)) int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items() for token, bias in logit_bias.items()
} }
if guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
structured_outputs = guided_decoding
guided_decoding = None
return SamplingParams( return SamplingParams(
n=1 if n is None else n, n=1 if n is None else n,
@ -334,6 +360,16 @@ class SamplingParams(
# eos_token_id is added to this by the engine # eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids) self._all_stop_token_ids.update(self.stop_token_ids)
if self.guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
self.structured_outputs = self.guided_decoding
self.guided_decoding = None
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int): if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of " raise ValueError(f"n must be an int, but is of "