mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 06:55:01 +08:00
Add backward compatibility for GuidedDecodingParams (#25422)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
cc1dc7ed6d
commit
875d6def90
@ -5,6 +5,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@ -21,7 +22,8 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
|
||||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
|
||||||
|
StructuredOutputsParams)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import TokenizerMode
|
from vllm.config import TokenizerMode
|
||||||
@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
|
|||||||
return json.loads(s)
|
return json.loads(s)
|
||||||
|
|
||||||
|
|
||||||
|
def test_guided_decoding_deprecated():
    """Check the deprecated guided-decoding names still work but warn.

    Covers three entry points: constructing ``GuidedDecodingParams``,
    passing ``guided_decoding=`` to the ``SamplingParams`` constructor, and
    passing it to ``SamplingParams.from_optional``.
    """
    class_pattern = "GuidedDecodingParams is deprecated.*"
    kwarg_pattern = "guided_decoding is deprecated.*"

    # Building the legacy params object must itself emit a warning.
    with pytest.warns(DeprecationWarning, match=class_pattern):
        legacy_params = GuidedDecodingParams(json_object=True)

    # The legacy class must expose the same dataclass fields as the new one.
    current_params = StructuredOutputsParams(json_object=True)
    assert fields(legacy_params) == fields(current_params)

    # Both construction paths warn about the deprecated keyword...
    with pytest.warns(DeprecationWarning, match=kwarg_pattern):
        from_init = SamplingParams(guided_decoding=legacy_params)

    with pytest.warns(DeprecationWarning, match=kwarg_pattern):
        from_factory = SamplingParams.from_optional(
            guided_decoding=legacy_params)

    # ...and end up equal, with the value stored under `structured_outputs`.
    assert from_init == from_factory
    assert from_init.structured_outputs == legacy_params
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name, backend, tokenizer_mode, speculative_config",
|
"model_name, backend, tokenizer_mode, speculative_config",
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
import copy
|
import copy
|
||||||
|
import warnings
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@ -59,6 +60,19 @@ class StructuredOutputsParams:
|
|||||||
f"but multiple are specified: {self.__dict__}")
|
f"but multiple are specified: {self.__dict__}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated subclass of ``StructuredOutputsParams``.

    Adds no fields or behaviour of its own: constructing it emits a
    ``DeprecationWarning`` and then runs the parent's ``__post_init__``.
    """

    def __post_init__(self):
        """Warn that this class is deprecated, then run the parent hook."""
        warnings.warn(
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.",
            DeprecationWarning,
            # __post_init__ is invoked from the dataclass __init__, so
            # frame 2 is that __init__ and frame 3 is the code that
            # actually constructed the object. Attributing the warning to
            # the caller lets it be reported against the user's own line.
            stacklevel=3)
        return super().__post_init__()
|
||||||
|
|
||||||
|
|
||||||
class RequestOutputKind(Enum):
|
class RequestOutputKind(Enum):
|
||||||
# Return entire output so far in every RequestOutput
|
# Return entire output so far in every RequestOutput
|
||||||
CUMULATIVE = 0
|
CUMULATIVE = 0
|
||||||
@ -179,6 +193,8 @@ class SamplingParams(
|
|||||||
# Fields used to construct logits processors
|
# Fields used to construct logits processors
|
||||||
structured_outputs: Optional[StructuredOutputsParams] = None
|
structured_outputs: Optional[StructuredOutputsParams] = None
|
||||||
"""Parameters for configuring structured outputs."""
|
"""Parameters for configuring structured outputs."""
|
||||||
|
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||||
|
"""Deprecated alias for structured_outputs."""
|
||||||
logit_bias: Optional[dict[int, float]] = None
|
logit_bias: Optional[dict[int, float]] = None
|
||||||
"""If provided, the engine will construct a logits processor that applies
|
"""If provided, the engine will construct a logits processor that applies
|
||||||
these logit biases."""
|
these logit biases."""
|
||||||
@ -227,6 +243,7 @@ class SamplingParams(
|
|||||||
ge=-1)]] = None,
|
ge=-1)]] = None,
|
||||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||||
structured_outputs: Optional[StructuredOutputsParams] = None,
|
structured_outputs: Optional[StructuredOutputsParams] = None,
|
||||||
|
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
||||||
allowed_token_ids: Optional[list[int]] = None,
|
allowed_token_ids: Optional[list[int]] = None,
|
||||||
extra_args: Optional[dict[str, Any]] = None,
|
extra_args: Optional[dict[str, Any]] = None,
|
||||||
@ -238,6 +255,15 @@ class SamplingParams(
|
|||||||
int(token): min(100.0, max(-100.0, bias))
|
int(token): min(100.0, max(-100.0, bias))
|
||||||
for token, bias in logit_bias.items()
|
for token, bias in logit_bias.items()
|
||||||
}
|
}
|
||||||
|
if guided_decoding is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"guided_decoding is deprecated. This will be removed in "
|
||||||
|
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||||
|
"structured_outputs instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2)
|
||||||
|
structured_outputs = guided_decoding
|
||||||
|
guided_decoding = None
|
||||||
|
|
||||||
return SamplingParams(
|
return SamplingParams(
|
||||||
n=1 if n is None else n,
|
n=1 if n is None else n,
|
||||||
@ -334,6 +360,16 @@ class SamplingParams(
|
|||||||
# eos_token_id is added to this by the engine
|
# eos_token_id is added to this by the engine
|
||||||
self._all_stop_token_ids.update(self.stop_token_ids)
|
self._all_stop_token_ids.update(self.stop_token_ids)
|
||||||
|
|
||||||
|
if self.guided_decoding is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"guided_decoding is deprecated. This will be removed in "
|
||||||
|
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||||
|
"structured_outputs instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2)
|
||||||
|
self.structured_outputs = self.guided_decoding
|
||||||
|
self.guided_decoding = None
|
||||||
|
|
||||||
def _verify_args(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
if not isinstance(self.n, int):
|
if not isinstance(self.n, int):
|
||||||
raise ValueError(f"n must be an int, but is of "
|
raise ValueError(f"n must be an int, but is of "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user