Scheduled removal of guided_* config fields (#29326)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-11-25 05:24:05 +00:00 committed by GitHub
parent 2d9ee28cab
commit 316c8492bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 262 deletions

View File

@ -7,7 +7,7 @@ This document shows you some examples of the different options that are
available to generate structured outputs.
!!! warning
If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
If you are still using the following deprecated API fields, which were removed in v0.12.0, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
- `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)`
- `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)`

View File

@ -3,7 +3,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from dataclasses import fields
from enum import Enum
from typing import TYPE_CHECKING, Any
@ -21,7 +20,6 @@ from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import (
GuidedDecodingParams,
SamplingParams,
StructuredOutputsParams,
)
@ -108,23 +106,6 @@ class CarDescription(BaseModel):
car_type: CarType
def test_guided_decoding_deprecated():
with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"):
guided_decoding = GuidedDecodingParams(json_object=True)
structured_outputs = StructuredOutputsParams(json_object=True)
assert fields(guided_decoding) == fields(structured_outputs)
with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
sp1 = SamplingParams(guided_decoding=guided_decoding)
with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
assert sp1 == sp2
assert sp1.structured_outputs == guided_decoding
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
@ -899,13 +880,11 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
output_json = json.loads(generated_text)
@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
def test_structured_output_with_structural_tag(
guided_decoding_backend: str,
):
@pytest.mark.parametrize("backend", ["xgrammar"])
def test_structured_output_with_structural_tag(backend: str):
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
guided_decoding_backend=guided_decoding_backend,
structured_outputs_config=StructuredOutputsConfig(backend=backend),
)
structural_tag_config = {
@ -923,7 +902,7 @@ def test_structured_output_with_structural_tag(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
guided_decoding=StructuredOutputsParams(
structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)
),
)

View File

@ -502,11 +502,6 @@ class EngineArgs:
)
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
reasoning_parser_plugin: str | None = None
# Deprecated guided decoding fields
guided_decoding_backend: str | None = None
guided_decoding_disable_fallback: bool | None = None
guided_decoding_disable_any_whitespace: bool | None = None
guided_decoding_disable_additional_properties: bool | None = None
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
@ -725,19 +720,6 @@ class EngineArgs:
"--reasoning-parser-plugin",
**structured_outputs_kwargs["reasoning_parser_plugin"],
)
# Deprecated guided decoding arguments
for arg, type in [
("--guided-decoding-backend", str),
("--guided-decoding-disable-fallback", bool),
("--guided-decoding-disable-any-whitespace", bool),
("--guided-decoding-disable-additional-properties", bool),
]:
structured_outputs_group.add_argument(
arg,
type=type,
help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
deprecated=True,
)
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
@ -1712,21 +1694,6 @@ class EngineArgs:
self.reasoning_parser_plugin
)
# Forward the deprecated CLI args to the StructuredOutputsConfig
so_config = self.structured_outputs_config
if self.guided_decoding_backend is not None:
so_config.guided_decoding_backend = self.guided_decoding_backend
if self.guided_decoding_disable_fallback is not None:
so_config.disable_fallback = self.guided_decoding_disable_fallback
if self.guided_decoding_disable_any_whitespace is not None:
so_config.disable_any_whitespace = (
self.guided_decoding_disable_any_whitespace
)
if self.guided_decoding_disable_additional_properties is not None:
so_config.disable_additional_properties = (
self.guided_decoding_disable_additional_properties
)
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,

View File

@ -652,62 +652,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
guided_json: str | dict | BaseModel | None = Field(
default=None,
description=(
"`guided_json` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `json` to `structured_outputs` instead."
),
)
guided_regex: str | None = Field(
default=None,
description=(
"`guided_regex` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `regex` to `structured_outputs` instead."
),
)
guided_choice: list[str] | None = Field(
default=None,
description=(
"`guided_choice` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `choice` to `structured_outputs` instead."
),
)
guided_grammar: str | None = Field(
default=None,
description=(
"`guided_grammar` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `grammar` to `structured_outputs` instead."
),
)
structural_tag: str | None = Field(
default=None,
description=(
"`structural_tag` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `structural_tag` to `structured_outputs` instead."
),
)
guided_decoding_backend: str | None = Field(
default=None,
description=(
"`guided_decoding_backend` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please remove it from your request."
),
)
guided_whitespace_pattern: str | None = Field(
default=None,
description=(
"`guided_whitespace_pattern` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `whitespace_pattern` to `structured_outputs` instead."
),
)
priority: int = Field(
default=0,
description=(
@ -841,20 +785,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# Forward deprecated guided_* parameters to structured_outputs
if self.structured_outputs is None:
kwargs = dict[str, Any](
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
whitespace_pattern=self.guided_whitespace_pattern,
structural_tag=self.structural_tag,
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if len(kwargs) > 0:
self.structured_outputs = StructuredOutputsParams(**kwargs)
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
@ -863,24 +793,23 @@ class ChatCompletionRequest(OpenAIBaseModel):
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
if response_format is not None:
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@ -1140,58 +1069,6 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
guided_json: str | dict | BaseModel | None = Field(
default=None,
description=(
"`guided_json` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `json` to `structured_outputs` instead."
),
)
guided_regex: str | None = Field(
default=None,
description=(
"`guided_regex` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `regex` to `structured_outputs` instead."
),
)
guided_choice: list[str] | None = Field(
default=None,
description=(
"`guided_choice` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `choice` to `structured_outputs` instead."
),
)
guided_grammar: str | None = Field(
default=None,
description=(
"`guided_grammar` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `grammar` to `structured_outputs` instead."
),
)
structural_tag: str | None = Field(
default=None,
description=("If specified, the output will follow the structural tag schema."),
)
guided_decoding_backend: str | None = Field(
default=None,
description=(
"`guided_decoding_backend` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please remove it from your request."
),
)
guided_whitespace_pattern: str | None = Field(
default=None,
description=(
"`guided_whitespace_pattern` is deprecated. "
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
"Please pass `whitespace_pattern` to `structured_outputs` instead."
),
)
priority: int = Field(
default=0,
description=(
@ -1336,35 +1213,31 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
guided_json_object = None
if self.response_format is not None:
if self.response_format.type == "json_object":
guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structural_tag = json.dumps(s_tag_obj)
# Forward deprecated guided_* parameters to structured_outputs
if self.structured_outputs is None:
kwargs = dict[str, Any](
json=self.guided_json,
json_object=guided_json_object,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
whitespace_pattern=self.guided_whitespace_pattern,
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if len(kwargs) > 0:
self.structured_outputs = StructuredOutputsParams(**kwargs)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:

View File

@ -3,7 +3,6 @@
"""Sampling parameters for text generation."""
import copy
import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
@ -100,19 +99,6 @@ class StructuredOutputsParams:
)
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
def __post_init__(self):
warnings.warn(
"GuidedDecodingParams is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"StructuredOutputsParams instead.",
DeprecationWarning,
stacklevel=2,
)
return super().__post_init__()
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@ -234,8 +220,6 @@ class SamplingParams(
# Fields used to construct logits processors
structured_outputs: StructuredOutputsParams | None = None
"""Parameters for configuring structured outputs."""
guided_decoding: GuidedDecodingParams | None = None
"""Deprecated alias for structured_outputs."""
logit_bias: dict[int, float] | None = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
@ -283,7 +267,6 @@ class SamplingParams(
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
guided_decoding: GuidedDecodingParams | None = None,
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
@ -295,16 +278,6 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
if guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2,
)
structured_outputs = guided_decoding
guided_decoding = None
return SamplingParams(
n=1 if n is None else n,
@ -387,17 +360,6 @@ class SamplingParams(
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)
if self.guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2,
)
self.structured_outputs = self.guided_decoding
self.guided_decoding = None
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of prompt logprobs may less than n_prompt_tokens,