mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 02:15:01 +08:00
Scheduled removal of guided_* config fields (#29326)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
2d9ee28cab
commit
316c8492bf
@ -7,7 +7,7 @@ This document shows you some examples of the different options that are
|
||||
available to generate structured outputs.
|
||||
|
||||
!!! warning
|
||||
If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
|
||||
If you are still using the following deprecated API fields which were removed in v0.12.0, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
|
||||
|
||||
- `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)`
|
||||
- `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)`
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from dataclasses import fields
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -21,7 +20,6 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
|
||||
from vllm.sampling_params import (
|
||||
GuidedDecodingParams,
|
||||
SamplingParams,
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
@ -108,23 +106,6 @@ class CarDescription(BaseModel):
|
||||
car_type: CarType
|
||||
|
||||
|
||||
def test_guided_decoding_deprecated():
|
||||
with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"):
|
||||
guided_decoding = GuidedDecodingParams(json_object=True)
|
||||
|
||||
structured_outputs = StructuredOutputsParams(json_object=True)
|
||||
assert fields(guided_decoding) == fields(structured_outputs)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
|
||||
sp1 = SamplingParams(guided_decoding=guided_decoding)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
|
||||
sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
|
||||
|
||||
assert sp1 == sp2
|
||||
assert sp1.structured_outputs == guided_decoding
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, backend, tokenizer_mode, speculative_config",
|
||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
|
||||
@ -899,13 +880,11 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
|
||||
output_json = json.loads(generated_text)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
|
||||
def test_structured_output_with_structural_tag(
|
||||
guided_decoding_backend: str,
|
||||
):
|
||||
@pytest.mark.parametrize("backend", ["xgrammar"])
|
||||
def test_structured_output_with_structural_tag(backend: str):
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
structured_outputs_config=StructuredOutputsConfig(backend=backend),
|
||||
)
|
||||
|
||||
structural_tag_config = {
|
||||
@ -923,7 +902,7 @@ def test_structured_output_with_structural_tag(
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
guided_decoding=StructuredOutputsParams(
|
||||
structured_outputs=StructuredOutputsParams(
|
||||
structural_tag=json.dumps(structural_tag_config)
|
||||
),
|
||||
)
|
||||
|
||||
@ -502,11 +502,6 @@ class EngineArgs:
|
||||
)
|
||||
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
|
||||
reasoning_parser_plugin: str | None = None
|
||||
# Deprecated guided decoding fields
|
||||
guided_decoding_backend: str | None = None
|
||||
guided_decoding_disable_fallback: bool | None = None
|
||||
guided_decoding_disable_any_whitespace: bool | None = None
|
||||
guided_decoding_disable_additional_properties: bool | None = None
|
||||
|
||||
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
|
||||
|
||||
@ -725,19 +720,6 @@ class EngineArgs:
|
||||
"--reasoning-parser-plugin",
|
||||
**structured_outputs_kwargs["reasoning_parser_plugin"],
|
||||
)
|
||||
# Deprecated guided decoding arguments
|
||||
for arg, type in [
|
||||
("--guided-decoding-backend", str),
|
||||
("--guided-decoding-disable-fallback", bool),
|
||||
("--guided-decoding-disable-any-whitespace", bool),
|
||||
("--guided-decoding-disable-additional-properties", bool),
|
||||
]:
|
||||
structured_outputs_group.add_argument(
|
||||
arg,
|
||||
type=type,
|
||||
help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
|
||||
deprecated=True,
|
||||
)
|
||||
|
||||
# Parallel arguments
|
||||
parallel_kwargs = get_kwargs(ParallelConfig)
|
||||
@ -1712,21 +1694,6 @@ class EngineArgs:
|
||||
self.reasoning_parser_plugin
|
||||
)
|
||||
|
||||
# Forward the deprecated CLI args to the StructuredOutputsConfig
|
||||
so_config = self.structured_outputs_config
|
||||
if self.guided_decoding_backend is not None:
|
||||
so_config.guided_decoding_backend = self.guided_decoding_backend
|
||||
if self.guided_decoding_disable_fallback is not None:
|
||||
so_config.disable_fallback = self.guided_decoding_disable_fallback
|
||||
if self.guided_decoding_disable_any_whitespace is not None:
|
||||
so_config.disable_any_whitespace = (
|
||||
self.guided_decoding_disable_any_whitespace
|
||||
)
|
||||
if self.guided_decoding_disable_additional_properties is not None:
|
||||
so_config.disable_additional_properties = (
|
||||
self.guided_decoding_disable_additional_properties
|
||||
)
|
||||
|
||||
observability_config = ObservabilityConfig(
|
||||
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
|
||||
otlp_traces_endpoint=self.otlp_traces_endpoint,
|
||||
|
||||
@ -652,62 +652,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
default=None,
|
||||
description="Additional kwargs for structured outputs",
|
||||
)
|
||||
guided_json: str | dict | BaseModel | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_json` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `json` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_regex: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_regex` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `regex` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_choice: list[str] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_choice` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `choice` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_grammar: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_grammar` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `grammar` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
structural_tag: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`structural_tag` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `structural_tag` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_decoding_backend: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_decoding_backend` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please remove it from your request."
|
||||
),
|
||||
)
|
||||
guided_whitespace_pattern: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_whitespace_pattern` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `whitespace_pattern` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
@ -841,20 +785,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.top_logprobs
|
||||
|
||||
# Forward deprecated guided_* parameters to structured_outputs
|
||||
if self.structured_outputs is None:
|
||||
kwargs = dict[str, Any](
|
||||
json=self.guided_json,
|
||||
regex=self.guided_regex,
|
||||
choice=self.guided_choice,
|
||||
grammar=self.guided_grammar,
|
||||
whitespace_pattern=self.guided_whitespace_pattern,
|
||||
structural_tag=self.structural_tag,
|
||||
)
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
if len(kwargs) > 0:
|
||||
self.structured_outputs = StructuredOutputsParams(**kwargs)
|
||||
|
||||
response_format = self.response_format
|
||||
if response_format is not None:
|
||||
# If structured outputs wasn't already enabled,
|
||||
@ -863,24 +793,23 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
self.structured_outputs = StructuredOutputsParams()
|
||||
|
||||
# Set structured output params for response format
|
||||
if response_format is not None:
|
||||
if response_format.type == "json_object":
|
||||
self.structured_outputs.json_object = True
|
||||
elif response_format.type == "json_schema":
|
||||
json_schema = response_format.json_schema
|
||||
assert json_schema is not None
|
||||
self.structured_outputs.json = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuralTagResponseFormat,
|
||||
),
|
||||
)
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
|
||||
if response_format.type == "json_object":
|
||||
self.structured_outputs.json_object = True
|
||||
elif response_format.type == "json_schema":
|
||||
json_schema = response_format.json_schema
|
||||
assert json_schema is not None
|
||||
self.structured_outputs.json = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuralTagResponseFormat,
|
||||
),
|
||||
)
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
@ -1140,58 +1069,6 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
default=None,
|
||||
description="Additional kwargs for structured outputs",
|
||||
)
|
||||
guided_json: str | dict | BaseModel | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_json` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `json` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_regex: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_regex` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `regex` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_choice: list[str] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_choice` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `choice` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
guided_grammar: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_grammar` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `grammar` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
structural_tag: str | None = Field(
|
||||
default=None,
|
||||
description=("If specified, the output will follow the structural tag schema."),
|
||||
)
|
||||
guided_decoding_backend: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_decoding_backend` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please remove it from your request."
|
||||
),
|
||||
)
|
||||
guided_whitespace_pattern: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"`guided_whitespace_pattern` is deprecated. "
|
||||
"This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
|
||||
"Please pass `whitespace_pattern` to `structured_outputs` instead."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
@ -1336,35 +1213,31 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
|
||||
echo_without_generation = self.echo and self.max_tokens == 0
|
||||
|
||||
guided_json_object = None
|
||||
if self.response_format is not None:
|
||||
if self.response_format.type == "json_object":
|
||||
guided_json_object = True
|
||||
elif self.response_format.type == "json_schema":
|
||||
json_schema = self.response_format.json_schema
|
||||
response_format = self.response_format
|
||||
if response_format is not None:
|
||||
# If structured outputs wasn't already enabled,
|
||||
# we must enable it for these features to work
|
||||
if self.structured_outputs is None:
|
||||
self.structured_outputs = StructuredOutputsParams()
|
||||
|
||||
# Set structured output params for response format
|
||||
if response_format.type == "json_object":
|
||||
self.structured_outputs.json_object = True
|
||||
elif response_format.type == "json_schema":
|
||||
json_schema = response_format.json_schema
|
||||
assert json_schema is not None
|
||||
self.guided_json = json_schema.json_schema
|
||||
elif self.response_format.type == "structural_tag":
|
||||
structural_tag = self.response_format
|
||||
self.structured_outputs.json = json_schema.json_schema
|
||||
elif response_format.type == "structural_tag":
|
||||
structural_tag = response_format
|
||||
assert structural_tag is not None and isinstance(
|
||||
structural_tag, StructuralTagResponseFormat
|
||||
structural_tag,
|
||||
(
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuralTagResponseFormat,
|
||||
),
|
||||
)
|
||||
s_tag_obj = structural_tag.model_dump(by_alias=True)
|
||||
self.structural_tag = json.dumps(s_tag_obj)
|
||||
|
||||
# Forward deprecated guided_* parameters to structured_outputs
|
||||
if self.structured_outputs is None:
|
||||
kwargs = dict[str, Any](
|
||||
json=self.guided_json,
|
||||
json_object=guided_json_object,
|
||||
regex=self.guided_regex,
|
||||
choice=self.guided_choice,
|
||||
grammar=self.guided_grammar,
|
||||
whitespace_pattern=self.guided_whitespace_pattern,
|
||||
)
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
if len(kwargs) > 0:
|
||||
self.structured_outputs = StructuredOutputsParams(**kwargs)
|
||||
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
@ -100,19 +99,6 @@ class StructuredOutputsParams:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuidedDecodingParams(StructuredOutputsParams):
|
||||
def __post_init__(self):
|
||||
warnings.warn(
|
||||
"GuidedDecodingParams is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"StructuredOutputsParams instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
# Return entire output so far in every RequestOutput
|
||||
CUMULATIVE = 0
|
||||
@ -234,8 +220,6 @@ class SamplingParams(
|
||||
# Fields used to construct logits processors
|
||||
structured_outputs: StructuredOutputsParams | None = None
|
||||
"""Parameters for configuring structured outputs."""
|
||||
guided_decoding: GuidedDecodingParams | None = None
|
||||
"""Deprecated alias for structured_outputs."""
|
||||
logit_bias: dict[int, float] | None = None
|
||||
"""If provided, the engine will construct a logits processor that applies
|
||||
these logit biases."""
|
||||
@ -283,7 +267,6 @@ class SamplingParams(
|
||||
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
structured_outputs: StructuredOutputsParams | None = None,
|
||||
guided_decoding: GuidedDecodingParams | None = None,
|
||||
logit_bias: dict[int, float] | dict[str, float] | None = None,
|
||||
allowed_token_ids: list[int] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
@ -295,16 +278,6 @@ class SamplingParams(
|
||||
int(token): min(100.0, max(-100.0, bias))
|
||||
for token, bias in logit_bias.items()
|
||||
}
|
||||
if guided_decoding is not None:
|
||||
warnings.warn(
|
||||
"guided_decoding is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"structured_outputs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
structured_outputs = guided_decoding
|
||||
guided_decoding = None
|
||||
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
@ -387,17 +360,6 @@ class SamplingParams(
|
||||
# eos_token_id is added to this by the engine
|
||||
self._all_stop_token_ids.update(self.stop_token_ids)
|
||||
|
||||
if self.guided_decoding is not None:
|
||||
warnings.warn(
|
||||
"guided_decoding is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"structured_outputs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.structured_outputs = self.guided_decoding
|
||||
self.guided_decoding = None
|
||||
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
# If prefix caching is enabled,
|
||||
# the output of prompt logprobs may less than n_prompt_tokens,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user