[v0][structured output] Support reasoning output (#12955)

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
Ce Gao 2025-03-03 03:49:42 +08:00 committed by GitHub
parent bc6ccb9878
commit bf33700ecd
16 changed files with 400 additions and 76 deletions

View File

@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
}
```
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests instead; see the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py), or the short sketch below.
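For illustration, a minimal streaming sketch using `requests` (the server address, model name, and prompt are assumptions; the `data:` prefix follows the standard OpenAI-compatible SSE stream format):
```python
# Minimal sketch, assuming a local vLLM server started with
# `--enable-reasoning --reasoning-parser deepseek_r1` and the model below.
import json

import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{
            "role": "user",
            "content": "Which is larger, 9.11 or 9.8?"
        }],
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines():
    # Server-sent events: payload lines start with "data: ".
    if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
        continue
    delta = json.loads(line[len(b"data: "):])["choices"][0]["delta"]
    # `reasoning_content` carries the thinking tokens, `content` the answer.
    print(delta.get("reasoning_content") or delta.get("content") or "", end="")
```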
## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with [`tool_calling`](#tool_calling).
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
## How to support a new reasoning model
@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
"""
```
After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
```python
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
```
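As a quick sanity check, the reasoner can be exercised directly. This is a minimal sketch; the model name is only an example, and it assumes the tokenizer encodes `<think>`/`</think>` as single tokens, as the `from_tokenizer` classmethod above does:
```python
# Minimal usage sketch for the reasoner defined above.
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import (
    DeepSeekReasoner)

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
reasoner = DeepSeekReasoner.from_tokenizer(tokenizer)

still_thinking = tokenizer.encode("<think>step 1 ...",
                                  add_special_tokens=False)
finished = tokenizer.encode("<think>step 1 ...</think> The answer is 42.",
                            add_special_tokens=False)
print(reasoner.is_reasoning_end(still_thinking))  # False: no "</think>" yet
print(reasoner.is_reasoning_end(finished))        # True: structured output can start
```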
Structured output engines such as xgrammar use `end_token_id` to check whether the reasoning phase has finished; as long as the end token has not appeared in the generated tokens, structured output is skipped. A self-contained toy sketch of this guard follows.
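Here the end-of-thinking token id and the "grammar" mask are arbitrary illustrations, not vLLM APIs:
```python
# Toy sketch: a logits processor that leaves the scores untouched until the
# end-of-thinking token has been generated, then applies its mask.
import torch

END_THINK_ID = 2  # stands in for the token id of "</think>"


def is_reasoning_end(input_ids: list[int]) -> bool:
    return END_THINK_ID in input_ids


def guided_logits(input_ids: list[int], scores: torch.Tensor) -> torch.Tensor:
    if not is_reasoning_end(input_ids):
        return scores                      # still thinking: no grammar enforced
    masked = torch.full_like(scores, float("-inf"))
    masked[0] = scores[0]                  # toy "grammar": only token 0 allowed
    return masked


scores = torch.zeros(8)
print(guided_logits([5, 7], scores))                # unchanged while reasoning
print(guided_logits([5, 7, END_THINK_ID], scores))  # masked once "</think>" appears
```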
Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
```bash
vllm serve <model_tag> \
--enable-reasoning --reasoning-parser example
```
## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.

View File

@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to generate structured outputs from reasoning models
such as DeepSeek R1. The thinking process is not guided by the JSON schema
provided by the user; only the final output is structured.
To run this example, you need to start the vLLM server with the reasoning
parser:
```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--enable-reasoning --reasoning-parser deepseek_r1
```
This example demonstrates how to generate chat completions from reasoning models
using the OpenAI Python client library.
"""
from enum import Enum
from openai import OpenAI
from pydantic import BaseModel
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's, think in 100 tokens")
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
extra_body={"guided_json": json_schema},
)
print("content", completion.choices[0].message.content)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)

View File

@ -16,17 +16,33 @@ from vllm.sampling_params import GuidedDecodingParams
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
def test_guided_logits_processors(sample_regex, sample_json_schema):
# Initialize the tokenizer for the model here to avoid repeated loading
@pytest.fixture(scope="module")
def zephyr_7B_tokenzer():
return AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
regex_LP = RegexLogitsProcessor(sample_regex,
zephyr_7B_tokenzer,
reasoner=None)
json_LP = JSONLogitsProcessor(sample_json_schema,
tokenizer,
whitespace_pattern=None)
zephyr_7B_tokenzer,
whitespace_pattern=None,
reasoner=None)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema):
sample_json_schema,
zephyr_7B_tokenzer):
config = ModelConfig(
MODEL_NAME,
@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
regex_request, tokenizer, config) if is_local else \
regex_request, zephyr_7B_tokenzer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, tokenizer, config)
regex_request, zephyr_7B_tokenzer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, tokenizer, config)
json_request, zephyr_7B_tokenzer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
sample_json_schema, deepseek_r1_qwen_tokenizer):
config = ModelConfig(
REASONING_MODEL_NAME,
task="generate",
tokenizer=REASONING_MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)

View File

@ -2715,6 +2715,8 @@ class DecodingConfig:
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
reasoning_backend: Optional[str] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -213,6 +213,8 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None
additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None
def __post_init__(self):
if not self.tokenizer:
@ -1059,6 +1061,25 @@ class EngineArgs:
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
parser.add_argument(
"--enable-reasoning",
action="store_true",
default=False,
help="Whether to enable reasoning_content for the model. "
"If enabled, the model will be able to generate reasoning content."
)
parser.add_argument(
"--reasoning-parser",
type=str,
choices=["deepseek_r1"],
default=None,
help=
"Select the reasoning parser depending on the model that you're "
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.")
return parser
@classmethod
@ -1332,7 +1353,10 @@ class EngineArgs:
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
guided_decoding_backend=self.guided_decoding_backend,
reasoning_backend=self.reasoning_parser
if self.enable_reasoning else None,
)
show_hidden_metrics = False
if self.show_hidden_metrics_for_version is not None:

View File

@ -509,6 +509,7 @@ class _AsyncLLMEngine(LLMEngine):
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)
self._add_processed_request(
@ -530,7 +531,7 @@ class _AsyncLLMEngine(LLMEngine):
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str,
default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
logger.info(
"Building guided decoding logits processor. "
"guided_decoding: %s%s", guided_decoding,
f", reasoning_backend: {reasoning_backend}"
if reasoning_backend is not None else "")
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
reasoning_backend=reasoning_backend,
model_config=model_config)
if processor:

View File

@ -2048,10 +2048,15 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
logger.debug("Reasoning backend: %s",
self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config)
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
if processor:
logits_processors.append(processor)

View File

@ -611,7 +611,8 @@ class MQLLMEngineClient(EngineClient):
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
# 1) Create output queue for this requests.

View File

@ -13,7 +13,6 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Enable auto tool choice for supported models. Use "
"``--tool-call-parser`` to specify which parser to use.")
parser.add_argument(
"--enable-reasoning",
action="store_true",
default=False,
help="Whether to enable reasoning_content for the model. "
"If enabled, the model will be able to generate reasoning content.")
valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
parser.add_argument(
"--reasoning-parser",
type=str,
metavar="{" + ",".join(valid_reasoning_parsers) + "}",
default=None,
help=
"Select the reasoning parser depending on the model that you're using."
" This is used to parse the reasoning content into OpenAI API "
"format. Required for ``--enable-reasoning``.")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
@ -103,8 +104,13 @@ def maybe_backend_fallback(
async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend)
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend_name == 'lm-format-enforcer':
guided_params, tokenizer, reasoner)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)
guided_params, tokenizer, model_config, reasoner)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor(
def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed; it will be None if the reasoning backend is not set.
reasoner = get_reasoner(tokenizer, reasoning_backend)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
guided_params, tokenizer, reasoner)
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)
guided_params, tokenizer, model_config, reasoner)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "

View File

@ -6,12 +6,13 @@ import os
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from typing import Optional, Tuple, Union
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.sampling_params import GuidedDecodingParams
@ -58,7 +59,9 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern)
mode, guided_params.whitespace_pattern,
reasoner)
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor(
return None
return _get_logits_processor(guide, tokenizer, mode,
guided_params.whitespace_pattern)
guided_params.whitespace_pattern, reasoner)
def _get_guide_and_mode(
@ -131,14 +137,18 @@ def _get_guide_and_mode(
def _get_logits_processor(
guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]
guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
return RegexLogitsProcessor(guide, tokenizer, reasoner)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer)
return CFGLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")

View File

@ -19,7 +19,7 @@ import copy
import json
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Union
import numpy as np
import torch
@ -32,13 +32,18 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
logger = init_logger(__name__)
class BaseLogitsProcessor:
def __init__(self, guide: Guide):
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
self._guide: Guide = guide
self._reasoner = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
@ -46,6 +51,14 @@ class BaseLogitsProcessor:
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--enable-reasoning` is set.
if self._reasoner is not None and \
not self._reasoner.is_reasoning_end(
input_ids):
return scores
seq_id = hash(tuple(input_ids))
if len(input_ids) > 0:
@ -113,7 +126,12 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
tokenizer = _adapt_tokenizer(tokenizer)
return RegexGuide.from_regex(regex_string, tokenizer)
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
def __init__(
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
):
"""Compile the FSM that drives the regex-structured generation.
Parameters
@ -125,14 +143,15 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
"""
super().__init__(
RegexLogitsProcessor._get_guide(regex_string, tokenizer))
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None]):
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
@ -160,7 +179,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)
super().__init__(regex_string, tokenizer, reasoner)
class CFGLogitsProcessor(BaseLogitsProcessor):
@ -171,7 +190,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
@ -182,7 +202,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
"""
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
reasoner)
self._guide = self._guide.copy()

View File

@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
DeepSeekReasoner)
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
def get_reasoner(tokenizer: PreTrainedTokenizer,
reasoning_backend: str | None) -> Reasoner | None:
if reasoning_backend is None:
# No reasoning backend specified
return None
elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer)
else:
raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
__all__ = ["Reasoner", "get_reasoner"]

View File

@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids

View File

@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
@abstractmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
pass
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
pass

View File

@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, List
import torch
from transformers import PreTrainedTokenizerFast
from vllm.logger import init_logger
try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
@ -19,7 +21,6 @@ except ImportError:
xgr_installed = False
pass
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@ -28,6 +29,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: Reasoner | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config)
return XGrammarLogitsProcessor(config, reasoner)
@dataclass(frozen=True)
@ -293,6 +296,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: Reasoner | None = None
ctx: xgr.CompiledGrammar | None = None
token_bitmask: torch.Tensor = None # type: ignore[assignment]
@ -301,10 +305,11 @@ class XGrammarLogitsProcessor:
prefilled: bool = field(default=False)
def __getstate__(self) -> dict[str, Any]:
return {'config': self.config}
return {'config': self.config, 'reasoner': self.reasoner}
def __setstate__(self, state: dict[str, Any]):
self.config = state['config']
self.reasoner = state['reasoner']
self.ctx = None
self.matchers = []
@ -331,6 +336,14 @@ class XGrammarLogitsProcessor:
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--enable-reasoning` is set.
if self.reasoner is not None and \
not self.reasoner.is_reasoning_end(
input_ids):
return scores
if self.ctx is None:
self._ensure_ctx()