[V1] Structured Outputs + Thinking compatibility (#16577)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Aaron Pham, 2025-05-14 18:45:24 -04:00, committed by GitHub
parent d93c976a0d, commit 2fc9075b82
10 changed files with 233 additions and 75 deletions

View File

@@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before
The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now.

```bash
-VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
```

-Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine.
+The following is an example client:

```python
from openai import OpenAI
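
# The diff truncates the doc's client example here. The rest of such a client
# might look roughly like the sketch below; this is an illustration, not the
# exact code from this commit. It assumes the server started above is
# listening on localhost:8000 and that the output is constrained with the
# `guided_json` extra parameter.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

json_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

completion = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{"role": "user", "content": "What is 2 + 2? Reply as JSON."}],
    extra_body={"guided_json": json_schema},
)

message = completion.choices[0].message
# `reasoning_content` holds the model's thinking (when a reasoning parser is
# enabled); `content` holds the schema-constrained JSON answer.
print(getattr(message, "reasoning_content", None))
print(message.content)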

View File

@@ -1,3 +1,4 @@
+# ruff: noqa: E501
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
@@ -5,17 +6,22 @@ from __future__ import annotations
import json
import re
from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any

import jsonschema
import pytest
from pydantic import BaseModel

+from tests.reasoning.utils import run_reasoning_extraction
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

+if TYPE_CHECKING:
+    from vllm.config import TokenizerMode
+
NGRAM_SPEC_CONFIG = {
    "model": "[ngram]",
    "num_speculative_tokens": 5,
@@ -444,7 +450,7 @@ def test_structured_output(
    prompt = """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
@@ -455,7 +461,7 @@ You have access to the following function to retrieve the weather in a city:
        }
    }
}

If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
@@ -476,7 +482,7 @@ Reminder:
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""
@@ -514,6 +520,88 @@ Make the response as short as possible.
                    f"{generated_text!r}\nError: {str(e)}")
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
+    [
+        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
+         "deepseek_r1", NGRAM_SPEC_CONFIG),
+        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
+    ],
+)
+def test_structured_output_with_reasoning_matrices(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    tokenizer_mode: TokenizerMode,
+    reasoning_parser: str,
+    model_name: str,
+    speculative_config: dict[str, Any] | None,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    if current_platform.is_tpu() and speculative_config:
+        pytest.skip("TPU does not support speculative decoding")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
+    llm = LLM(
+        model=model_name,
+        # Don't use eager execution on TPUs because we want to test for no
+        # recompilation at runtime
+        enforce_eager=bool(not current_platform.is_tpu()),
+        max_model_len=1024,
+        max_num_seqs=16,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=True,
+        tokenizer_mode=tokenizer_mode,
+        reasoning_parser=reasoning_parser,
+        speculative_config=speculative_config,
+    )
+    tokenizer = llm.get_tokenizer(None)
+    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
+        tokenizer=tokenizer)
+
+    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
+    reasoning_schema = {
+        "type": "object",
+        "properties": {
+            "result": {
+                "type": "integer"
+            }
+        },
+        "required": ["result"],
+        "additionalProperties": False
+    }
+    if "Qwen3" in model_name:
+        reasoning_prompt += "<think>\n"
+
+    sampling_params = SamplingParams(
+        temperature=0.1,
+        max_tokens=8192,
+        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
+    )
+    outputs = llm.generate(
+        [reasoning_prompt],
+        sampling_params=sampling_params,
+        use_tqdm=True,
+    )
+
+    assert outputs is not None
+    output = outputs[0]
+    assert output is not None and isinstance(output, RequestOutput)
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    reasoning_content, content = run_reasoning_extraction(
+        reasoner, [generated_text])
+    print(
+        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
+    )
+
+    assert content is not None and reasoning_content is not None
+    output_json = json.loads(content)
+    jsonschema.validate(instance=output_json, schema=reasoning_schema)
+

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode",
                         PARAMS_MODELS_TOKENIZER_MODE)

View File

@@ -2332,7 +2332,7 @@ class SpeculativeConfig:
    `TypicalAcceptanceSampler`."""

    speculative_token_tree: Optional[str] = None
    """Specifies the tree structure for speculative token generation.
    """
    # required configuration params passed from engine
    target_model_config: ModelConfig = field(default=None,
@@ -4024,7 +4024,7 @@ class VllmConfig:
    """LoRA configuration."""
    speculative_config: Optional[SpeculativeConfig] = None
    """Speculative decoding configuration."""
-    decoding_config: Optional[DecodingConfig] = None
+    decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
    """Decoding configuration."""
    observability_config: Optional[ObservabilityConfig] = None
    """Observability configuration."""

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
import os
from abc import abstractmethod
from collections.abc import Sequence
@@ -33,7 +35,7 @@ class ReasoningParser:
        return self.model_tokenizer.get_vocab()

    @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """
        Check if the reasoning content ends in the input_ids.
@@ -106,7 +108,7 @@ class ReasoningParserManager:
    reasoning_parsers: dict[str, type] = {}

    @classmethod
-    def get_reasoning_parser(cls, name) -> type:
+    def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
        """
        Get reasoning parser by name which is registered by `register_module`.

View File

@@ -758,7 +758,8 @@ class Scheduler(SchedulerInterface):
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_index, req_index + 1)

-            if new_token_ids and request.use_structured_output:
+            if new_token_ids and self.structured_output_manager.should_advance(
+                    request):
                # NOTE: structured_output_request
                # should not be None if use_structured_output, we have
                # check above, so safe to ignore type warning
@@ -767,11 +768,10 @@ class Scheduler(SchedulerInterface):

            # Add newly generated spec token ids to the request.
            if spec_token_ids is not None:
-                if request.use_structured_output:
+                if self.structured_output_manager.should_advance(request):
                    metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                        spec_token_ids[req_index])
                else:
                    request.spec_token_ids = spec_token_ids[req_index]

View File

@@ -7,16 +7,23 @@ from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
+from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

+    from vllm.reasoning import ReasoningParser
    from vllm.v1.request import Request
+else:
+    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)
@@ -26,9 +33,11 @@ class StructuredOutputManager:

    def __init__(self, vllm_config: VllmConfig):
        self.backend: Optional[StructuredOutputBackend] = None
+        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config
        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
@@ -36,24 +45,43 @@ class StructuredOutputManager:
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.tokenizer = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
+        ).get_lora_tokenizer(None)
+        reasoning_backend = vllm_config.decoding_config.reasoning_backend
+        if reasoning_backend:
+            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                reasoning_backend)
+            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        if request.structured_output_request is None:
            return

+        if TYPE_CHECKING:
+            assert request.sampling_params.guided_decoding is not None
+
        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
+            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
-                from vllm.v1.structured_output.backend_xgrammar import (
-                    XgrammarBackend)
-
-                self.backend = XgrammarBackend(self.vllm_config)
+                self.backend = XgrammarBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
            elif backend == "guidance":
-                self.backend = GuidanceBackend(self.vllm_config)
+                self.backend = GuidanceBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")
@@ -87,14 +115,14 @@ class StructuredOutputManager:
        if not structured_output_request_ids:
            return None

+        max_num_spec_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            max_num_spec_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
-            if self.vllm_config.speculative_config is not None:
-                max_num_spec_tokens = self.vllm_config.\
-                    speculative_config.num_speculative_tokens
-            else:
-                max_num_spec_tokens = 0

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
@@ -103,6 +131,7 @@ class StructuredOutputManager:
            self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens))

+        bitmask_tensor = self._grammar_bitmask
        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
@@ -110,16 +139,30 @@ class StructuredOutputManager:
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])
+
+        # Note that for thinking support, we will need to
+        # reset the relevant part of the bitmask for consequent
+        # request here.
+        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
+            self._full_mask)
+
        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id].structured_output_request
-            assert request is not None and request.grammar is not None
+            if TYPE_CHECKING:
+                assert request is not None
+                assert request.grammar is not None
+            apply_bitmask = (
+                request.reasoning_ended if self.reasoner is not None else True
+            )  # noqa: E501

            state_advancements = 0
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for i, token in enumerate(req_tokens):
-                if not request.grammar.is_terminated():
-                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                if apply_bitmask and not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(bitmask_tensor,
                                                 cumulative_index)
                if token is not None:
                    # In order to generate the correct bitmask for each
@@ -132,15 +175,41 @@ class StructuredOutputManager:
            if state_advancements > 0:
                request.grammar.rollback(state_advancements)

-        bitmask_tensor = self._grammar_bitmask
-        if cumulative_index < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:cumulative_index]
+        if cumulative_index < bitmask_tensor.shape[0]:
+            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

+    def should_advance(self, request: Request) -> bool:
+        if not request.use_structured_output:
+            return False
+
+        # To determine whether we can advance the FSM.
+        # Supports thinking usage where we skip the reasoning components.
+        if TYPE_CHECKING:
+            assert request.structured_output_request is not None
+            assert request.structured_output_request.grammar is not None
+        # by default, we should always advance
+        # for cases that doesn't uses thinking mode.
+        if self.reasoner is not None:
+            structured_req = request.structured_output_request
+            if structured_req.reasoning_ended:
+                return True
+
+            # Check if reasoning ends in *this* step
+            if self.reasoner.is_reasoning_end(request.all_token_ids):
+                # Reasoning just ended, so we shouldn't advanced til
+                # next pass
+                structured_req.reasoning_ended = True
+
+            return False
+        else:
+            return True
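
In short, `should_advance` keeps the grammar FSM idle while the model is still inside its reasoning block and only starts enforcing the schema once the reasoning parser reports that thinking has ended. A standalone sketch of that gating behavior, using hypothetical stand-ins rather than vLLM's real classes:

```python
class FakeReasoningParser:
    """Hypothetical parser: reasoning ends once the end-of-think token id appears."""

    END_THINK_ID = 42

    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        return self.END_THINK_ID in token_ids


class FakeRequest:
    def __init__(self) -> None:
        self.all_token_ids: list[int] = []
        self.reasoning_ended = False


def should_advance(reasoner: FakeReasoningParser, request: FakeRequest) -> bool:
    # Mirrors the gating above: skip grammar advancement while reasoning.
    if request.reasoning_ended:
        return True
    if reasoner.is_reasoning_end(request.all_token_ids):
        # Reasoning just ended; start enforcing the grammar on the next step.
        request.reasoning_ended = True
    return False


reasoner, req = FakeReasoningParser(), FakeRequest()
for step, token in enumerate([7, 9, 42, 13, 15]):  # 42 closes the thinking block
    req.all_token_ids.append(token)
    print(f"step={step} token={token} advance={should_advance(reasoner, req)}")
# advance is False for steps 0-2 and True from step 3 onwards.
```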
    def clear_backend(self) -> None:
        if self.backend is not None:
            self.backend.destroy()

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
import copy
import json
import os
@@ -8,10 +10,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union

import torch

-from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar,
@@ -54,25 +54,17 @@ def process_for_additional_properties(
    return guide_json_obj


+@dataclass
class GuidanceBackend(StructuredOutputBackend):

-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-
-        self.vllm_config = vllm_config
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
+    def __post_init__(self):
        self.disable_any_whitespace = \
-            vllm_config.decoding_config.disable_any_whitespace
+            self.vllm_config.decoding_config.disable_any_whitespace
        self.disable_additional_properties = \
-            vllm_config.decoding_config.disable_additional_properties
+            self.vllm_config.decoding_config.disable_additional_properties

-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
-            tokenizer, self.vocab_size)
+            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:

View File

@@ -1,9 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
import enum
from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING

-import torch
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.config import VllmConfig
+    from vllm.transformers_utils.tokenizer import AnyTokenizer


class StructuredOutputOptions(enum.Enum):
@@ -85,9 +93,14 @@ class StructuredOutputGrammar(ABC):
    """


+@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests."""

+    vllm_config: VllmConfig
+    tokenizer: AnyTokenizer
+    vocab_size: int
+
    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
@@ -104,7 +117,7 @@ class StructuredOutputBackend(ABC):
        """

    @abstractmethod
-    def allocate_token_bitmask(self, max_num_seqs: int):
+    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """
        Allocates a token bitmask for the specified maximum number of sequences.
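
Because `StructuredOutputBackend` is now a `@dataclass` ABC, concrete backends receive `vllm_config`, `tokenizer`, and `vocab_size` through the generated constructor and do their setup in `__post_init__` instead of a hand-written `__init__`, as the GuidanceBackend and XgrammarBackend diffs in this commit show. A generic sketch of that pattern, with hypothetical names unrelated to vLLM's real classes:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Backend(ABC):
    # Shared constructor fields, analogous to StructuredOutputBackend's.
    config: dict
    tokenizer: object
    vocab_size: int

    @abstractmethod
    def compile(self, spec: str) -> str:
        ...


@dataclass
class EchoBackend(Backend):
    def __post_init__(self) -> None:
        # Runs after the generated __init__ has assigned the dataclass fields.
        self.prefix = f"[vocab={self.vocab_size}] "

    def compile(self, spec: str) -> str:
        return self.prefix + spec


backend = EchoBackend({}, tokenizer=None, vocab_size=32000)
print(backend.compile("json"))  # -> "[vocab=32000] json"
```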

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
@@ -7,10 +9,8 @@ from typing import TYPE_CHECKING, Any

import torch

import vllm.envs
-from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -28,61 +28,49 @@ else:

logger = init_logger(__name__)
+@dataclass
class XgrammarBackend(StructuredOutputBackend):

-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-
+    def __post_init__(self):
        self.disable_any_whitespace = \
-            vllm_config.decoding_config.disable_any_whitespace
+            self.vllm_config.decoding_config.disable_any_whitespace

-        self.num_speculative_tokens = 0
-        if self.vllm_config.speculative_config is not None:
-            self.num_speculative_tokens = \
-                self.vllm_config.speculative_config.num_speculative_tokens
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
+        if isinstance(self.tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
-                if tokenizer.is_tekken:
-                    encoded_vocab = tokenizer._vocab
+                if self.tokenizer.is_tekken:
+                    encoded_vocab = self.tokenizer._vocab
                else:
                    encoded_vocab = [
                        token for token, _ in sorted(
-                            tokenizer.get_vocab().items(),
+                            self.tokenizer.get_vocab().items(),
                            key=lambda x: x[1],
                        )
                    ]
                stop_token_ids = None
-                if hasattr(
-                        tokenizer,
+                if (hasattr(
+                        self.tokenizer,
                        "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
+                ) and self.tokenizer.eos_token_id is not None):
+                    stop_token_ids = [self.tokenizer.eos_token_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
+                    f"{type(self.tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                vocab_type=xgr.VocabType.RAW
-                if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
+                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
+                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(
@@ -92,6 +80,11 @@ class XgrammarBackend(StructuredOutputBackend):
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

+        self.num_speculative_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            self.num_speculative_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        if request_type == StructuredOutputOptions.JSON:

View File

@@ -20,6 +20,7 @@ class StructuredOutputRequest:
    sampling_params: SamplingParams
    _grammar: Optional[Union[Future[StructuredOutputGrammar],
                             StructuredOutputGrammar]] = None
+    reasoning_ended: bool = False

    def _check_grammar_completion(self) -> bool:
        # NOTE: We have to lazy import to gate circular imports