[V1] Structured Outputs + Thinking compatibility (#16577)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>

parent d93c976a0d
commit 2fc9075b82
@@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before
 The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now.
 
 ```bash
-VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
 ```
 
-Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine.
+The following is an example client:
 
 ```python
 from openai import OpenAI

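The diff view truncates the example client right after the `from openai import OpenAI` line. A minimal sketch of such a client, assuming a local server started with the command above, the OpenAI-compatible `extra_body={"guided_json": ...}` mechanism, and the `reasoning_content` field described in the surrounding docs; the exact example in the repository may differ:

```python
from openai import OpenAI

# Assumes a local vLLM OpenAI-compatible server on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

schema = {
    "type": "object",
    "properties": {"result": {"type": "integer"}},
    "required": ["result"],
}

completion = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{"role": "user", "content": "What is 5 * 8 + 2? Answer as JSON."}],
    extra_body={"guided_json": schema},
)

message = completion.choices[0].message
# reasoning_content carries the thinking section; content carries the structured answer.
print("reasoning:", getattr(message, "reasoning_content", None))
print("content:", message.content)
```
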
@@ -1,3 +1,4 @@
+# ruff: noqa: E501
 # SPDX-License-Identifier: Apache-2.0
 
 from __future__ import annotations

@@ -5,17 +6,22 @@ from __future__ import annotations
 import json
 import re
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import jsonschema
 import pytest
 from pydantic import BaseModel
 
+from tests.reasoning.utils import run_reasoning_extraction
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
+if TYPE_CHECKING:
+    from vllm.config import TokenizerMode
+
 NGRAM_SPEC_CONFIG = {
     "model": "[ngram]",
     "num_speculative_tokens": 5,

@@ -444,7 +450,7 @@ def test_structured_output(
 
     prompt = """
 You have access to the following function to retrieve the weather in a city:
 
     {
         "name": "get_weather",
         "parameters": {

@@ -455,7 +461,7 @@ You have access to the following function to retrieve the weather in a city:
             }
         }
     }
 
 If a you choose to call a function ONLY reply in the following format:
 <{start_tag}={function_name}>{parameters}{end_tag}
 where

@@ -476,7 +482,7 @@ Reminder:
 - Always add your sources when using search results to answer the user query
 
 You are a helpful assistant.
 
 Given the previous instructions, what is the weather in New York City? \
 Make the response as short as possible.
 """

@@ -514,6 +520,88 @@ Make the response as short as possible.
                 f"{generated_text!r}\nError: {str(e)}")
 
 
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
+    [
+        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
+         "deepseek_r1", NGRAM_SPEC_CONFIG),
+        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
+    ],
+)
+def test_structured_output_with_reasoning_matrices(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    tokenizer_mode: TokenizerMode,
+    reasoning_parser: str,
+    model_name: str,
+    speculative_config: dict[str, Any] | None,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    if current_platform.is_tpu() and speculative_config:
+        pytest.skip("TPU does not support speculative decoding")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
+    llm = LLM(
+        model=model_name,
+        # Don't use eager execution on TPUs because we want to test for no
+        # recompilation at runtime
+        enforce_eager=bool(not current_platform.is_tpu()),
+        max_model_len=1024,
+        max_num_seqs=16,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=True,
+        tokenizer_mode=tokenizer_mode,
+        reasoning_parser=reasoning_parser,
+        speculative_config=speculative_config,
+    )
+    tokenizer = llm.get_tokenizer(None)
+    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
+        tokenizer=tokenizer)
+
+    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
+    reasoning_schema = {
+        "type": "object",
+        "properties": {
+            "result": {
+                "type": "integer"
+            }
+        },
+        "required": ["result"],
+        "additionalProperties": False
+    }
+    if "Qwen3" in model_name:
+        reasoning_prompt += "<think>\n"
+
+    sampling_params = SamplingParams(
+        temperature=0.1,
+        max_tokens=8192,
+        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
+    )
+    outputs = llm.generate(
+        [reasoning_prompt],
+        sampling_params=sampling_params,
+        use_tqdm=True,
+    )
+
+    assert outputs is not None
+    output = outputs[0]
+    assert output is not None and isinstance(output, RequestOutput)
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    reasoning_content, content = run_reasoning_extraction(
+        reasoner, [generated_text])
+    print(
+        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
+    )
+
+    assert content is not None and reasoning_content is not None
+    output_json = json.loads(content)
+    jsonschema.validate(instance=output_json, schema=reasoning_schema)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("model_name, tokenizer_mode",
                          PARAMS_MODELS_TOKENIZER_MODE)

@@ -2332,7 +2332,7 @@ class SpeculativeConfig:
     `TypicalAcceptanceSampler`."""
 
     speculative_token_tree: Optional[str] = None
     """Specifies the tree structure for speculative token generation.
     """
     # required configuration params passed from engine
     target_model_config: ModelConfig = field(default=None,

@@ -4024,7 +4024,7 @@ class VllmConfig:
     """LoRA configuration."""
     speculative_config: Optional[SpeculativeConfig] = None
     """Speculative decoding configuration."""
-    decoding_config: Optional[DecodingConfig] = None
+    decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
     """Decoding configuration."""
     observability_config: Optional[ObservabilityConfig] = None
     """Observability configuration."""

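The `decoding_config` change above (an `Optional` defaulting to `None` becoming a `default_factory`) is what lets later code in this commit read `vllm_config.decoding_config.reasoning_backend` without a `None` check. A toy illustration of the pattern, using stand-in classes rather than vLLM's real ones:

```python
from dataclasses import dataclass, field


# Toy illustration, not vLLM's actual config classes: with default_factory the
# attribute is always a DecodingConfig instance, so callers such as the
# StructuredOutputManager can read decoding_config.reasoning_backend directly
# instead of guarding against None.
@dataclass
class DecodingConfig:
    reasoning_backend: str = ""


@dataclass
class VllmConfig:
    decoding_config: DecodingConfig = field(default_factory=DecodingConfig)


cfg = VllmConfig()
assert cfg.decoding_config.reasoning_backend == ""
```
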
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import os
 from abc import abstractmethod
 from collections.abc import Sequence

@@ -33,7 +35,7 @@ class ReasoningParser:
         return self.model_tokenizer.get_vocab()
 
     @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
         Check if the reasoning content ends in the input_ids.
 

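Widening `is_reasoning_end` from `list[int]` to `Sequence[int]` lets callers pass any token-id sequence (for example the request's `all_token_ids` view used later in this diff) without copying. A minimal sketch of a parser honoring that contract, with a made-up end-of-think token id:

```python
from collections.abc import Sequence

# Illustrative only, not a registered vLLM parser: a DeepSeek-R1-style parser
# treats reasoning as finished once the end-of-think token has been generated,
# so the check only needs membership over any integer sequence -- a plain
# list, a tuple, or a read-only view of the request's token ids.
THINK_END_TOKEN_ID = 128014  # hypothetical id for "</think>"


class ToyReasoningParser:

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return THINK_END_TOKEN_ID in input_ids


parser = ToyReasoningParser()
assert parser.is_reasoning_end((1, 2, THINK_END_TOKEN_ID, 7))
assert not parser.is_reasoning_end([1, 2, 3])
```
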
@@ -106,7 +108,7 @@ class ReasoningParserManager:
     reasoning_parsers: dict[str, type] = {}
 
     @classmethod
-    def get_reasoning_parser(cls, name) -> type:
+    def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
         """
         Get reasoning parser by name which is registered by `register_module`.
 

@@ -758,7 +758,8 @@ class Scheduler(SchedulerInterface):
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
-            if new_token_ids and request.use_structured_output:
+            if new_token_ids and self.structured_output_manager.should_advance(
+                    request):
                 # NOTE: structured_output_request
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning

@@ -767,11 +768,10 @@ class Scheduler(SchedulerInterface):
 
             # Add newly generated spec token ids to the request.
             if spec_token_ids is not None:
-                if request.use_structured_output:
+                if self.structured_output_manager.should_advance(request):
                     metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
                     # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                         spec_token_ids[req_index])
                 else:
                     request.spec_token_ids = spec_token_ids[req_index]

@@ -7,16 +7,23 @@ from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
+from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
 
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
     import torch
 
+    from vllm.reasoning import ReasoningParser
     from vllm.v1.request import Request
+else:
+    torch = LazyLoader("torch", globals(), "torch")
 
 logger = init_logger(__name__)
 

@@ -26,9 +33,11 @@ class StructuredOutputManager:
 
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
+        self.reasoner: Optional[ReasoningParser] = None
         self.vllm_config = vllm_config
 
         self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
         # The default max_workers if not specified is the number of CPUs * 5,
         # which is way too high since these tasks are CPU-bound, not I/O bound.

@@ -36,24 +45,43 @@ class StructuredOutputManager:
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.tokenizer = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
+        ).get_lora_tokenizer(None)
+        reasoning_backend = vllm_config.decoding_config.reasoning_backend
+        if reasoning_backend:
+            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                reasoning_backend)
+            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
 
+        if TYPE_CHECKING:
+            assert request.sampling_params.guided_decoding is not None
+
         # Initialize the backend the first time it is needed.
         #
         # NOTE: We only support a single backend. We do NOT support different
         # backends on a per-request basis in V1 (for now, anyway...).
         if self.backend is None:
             backend = request.sampling_params.guided_decoding.backend
+            vocab_size = self.vllm_config.model_config.get_vocab_size()
             if backend == "xgrammar":
-                from vllm.v1.structured_output.backend_xgrammar import (
-                    XgrammarBackend)
-
-                self.backend = XgrammarBackend(self.vllm_config)
+                self.backend = XgrammarBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
             elif backend == "guidance":
-                self.backend = GuidanceBackend(self.vllm_config)
+                self.backend = GuidanceBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
             else:
                 raise ValueError(
                     f"Unsupported structured output backend: {backend}")

@@ -87,14 +115,14 @@ class StructuredOutputManager:
         if not structured_output_request_ids:
             return None
 
+        max_num_spec_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            max_num_spec_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
         if self._grammar_bitmask is None:
             assert self.backend is not None
             max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
-            if self.vllm_config.speculative_config is not None:
-                max_num_spec_tokens = self.vllm_config.\
-                    speculative_config.num_speculative_tokens
-            else:
-                max_num_spec_tokens = 0
 
             # Allocate a bitmask for each token needing to be checked:
             # one for each speculative position, and one more for the

@@ -103,6 +131,7 @@ class StructuredOutputManager:
             self.backend.allocate_token_bitmask(
                 max_batch_size * (1 + max_num_spec_tokens))
 
+        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.

@@ -110,16 +139,30 @@ class StructuredOutputManager:
         cumulative_index = 0
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
+
+        # Note that for thinking support, we will need to
+        # reset the relevant part of the bitmask for consequent
+        # request here.
+        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
+            self._full_mask)
+
         # NOTE: This outer loop can likely be parallelized to improve
         # performance of bitmask generation for large batches.
         for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
-            assert request is not None and request.grammar is not None
+            if TYPE_CHECKING:
+                assert request is not None
+                assert request.grammar is not None
+
+            apply_bitmask = (
+                request.reasoning_ended if self.reasoner is not None else True
+            )  # noqa: E501
+
             state_advancements = 0
             req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
             for i, token in enumerate(req_tokens):
-                if not request.grammar.is_terminated():
-                    request.grammar.fill_bitmask(self._grammar_bitmask,
-                                                 cumulative_index)
+                if apply_bitmask and not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(bitmask_tensor,
+                                                 cumulative_index)
                 if token is not None:
                     # In order to generate the correct bitmask for each

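The new `fill_(self._full_mask)` call above resets each scheduled request's rows to -1 before any grammar writes them, so requests that are still in the reasoning phase (where `apply_bitmask` is false) keep an all-allowed mask. A small sketch of why -1 means "allow every token" for a packed int32 bitmask, assuming the xgrammar-style 32-bits-per-int layout:

```python
import torch

# Illustrative only: token bitmasks pack 32 allow/deny bits per int32, so a
# value of -1 (all bits set) marks every token as allowed. Rows belonging to
# requests that are still reasoning stay in this state because fill_bitmask()
# is skipped for them.
vocab_size = 100
num_words = (vocab_size + 31) // 32
row = torch.full((1, num_words), -1, dtype=torch.int32)
assert bool((row == -1).all())  # the whole vocabulary remains permitted
```
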
@@ -132,15 +175,41 @@ class StructuredOutputManager:
                 if state_advancements > 0:
                     request.grammar.rollback(state_advancements)
 
-        bitmask_tensor = self._grammar_bitmask
-        if cumulative_index < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:cumulative_index]
+        if cumulative_index < bitmask_tensor.shape[0]:
+            bitmask_tensor = bitmask_tensor[:cumulative_index]
 
         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
+    def should_advance(self, request: Request) -> bool:
+        if not request.use_structured_output:
+            return False
+
+        # To determine whether we can advance the FSM.
+        # Supports thinking usage where we skip the reasoning components.
+        if TYPE_CHECKING:
+            assert request.structured_output_request is not None
+            assert request.structured_output_request.grammar is not None
+        # by default, we should always advance
+        # for cases that doesn't uses thinking mode.
+        if self.reasoner is not None:
+            structured_req = request.structured_output_request
+
+            if structured_req.reasoning_ended:
+                return True
+
+            # Check if reasoning ends in *this* step
+            if self.reasoner.is_reasoning_end(request.all_token_ids):
+                # Reasoning just ended, so we shouldn't advanced til
+                # next pass
+                structured_req.reasoning_ended = True
+
+            return False
+        else:
+            return True
+
     def clear_backend(self) -> None:
         if self.backend is not None:
             self.backend.destroy()

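A toy walkthrough (illustrative names only, not vLLM internals) of the `should_advance` state machine added above: the grammar is never advanced while the model is still thinking, the step in which the reasoning end marker first appears only records that fact, and every later step advances as usual:

```python
class ToyStructuredRequest:

    def __init__(self) -> None:
        self.reasoning_ended = False


def should_advance(req: ToyStructuredRequest, reasoning_end_seen: bool) -> bool:
    if req.reasoning_ended:
        return True
    if reasoning_end_seen:
        # Reasoning just ended; advancing starts on the *next* scheduler step.
        req.reasoning_ended = True
    return False


req = ToyStructuredRequest()
assert should_advance(req, reasoning_end_seen=False) is False  # still thinking
assert should_advance(req, reasoning_end_seen=True) is False   # end marker emitted
assert should_advance(req, reasoning_end_seen=True) is True    # structured phase
```
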
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import copy
 import json
 import os

@@ -8,10 +10,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar,

@@ -54,25 +54,17 @@ def process_for_additional_properties(
     return guide_json_obj
 
 
+@dataclass
 class GuidanceBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        self.vllm_config = vllm_config
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-
+    def __post_init__(self):
         self.disable_any_whitespace = \
-            vllm_config.decoding_config.disable_any_whitespace
+            self.vllm_config.decoding_config.disable_any_whitespace
         self.disable_additional_properties = \
-            vllm_config.decoding_config.disable_additional_properties
+            self.vllm_config.decoding_config.disable_additional_properties
 
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
-            tokenizer, self.vocab_size)
+            self.tokenizer, self.vocab_size)
 
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:

@@ -1,9 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import enum
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
 
-import torch
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.config import VllmConfig
+    from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 
 class StructuredOutputOptions(enum.Enum):

@@ -85,9 +93,14 @@ class StructuredOutputGrammar(ABC):
     """
 
 
+@dataclass
 class StructuredOutputBackend(ABC):
     """Engine-level backend for structured output requests."""
 
+    vllm_config: VllmConfig
+    tokenizer: AnyTokenizer
+    vocab_size: int
+
     @abstractmethod
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:

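Turning `StructuredOutputBackend` into a dataclass with `vllm_config`, `tokenizer`, and `vocab_size` fields is what allows `grammar_init` (earlier in this diff) to construct `XgrammarBackend(self.vllm_config, tokenizer=..., vocab_size=...)` while subclasses do their setup in `__post_init__`. A toy version of the pattern, not vLLM's actual classes:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


# Toy sketch: the dataclass ABC declares the shared fields once, and concrete
# backends keep the generated constructor shape, doing derived setup in
# __post_init__ instead of overriding __init__.
@dataclass
class ToyBackend(ABC):
    vocab_size: int

    @abstractmethod
    def compile(self, spec: str) -> str:
        ...


@dataclass
class ToyXgrammarBackend(ToyBackend):

    def __post_init__(self):
        self.any_whitespace = False  # derived state, set up after field init

    def compile(self, spec: str) -> str:
        return f"grammar({spec}, vocab={self.vocab_size})"


assert ToyXgrammarBackend(vocab_size=8).compile("json") == "grammar(json, vocab=8)"
```
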
@@ -104,7 +117,7 @@ class StructuredOutputBackend(ABC):
     """
 
     @abstractmethod
-    def allocate_token_bitmask(self, max_num_seqs: int):
+    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
         """
         Allocates a token bitmask for the specified maximum number of sequences.
 

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import json
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any

@@ -7,10 +9,8 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 
 import vllm.envs
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,

@@ -28,61 +28,49 @@ else:
 logger = init_logger(__name__)
 
 
+@dataclass
 class XgrammarBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-
+    def __post_init__(self):
         self.disable_any_whitespace = \
-            vllm_config.decoding_config.disable_any_whitespace
+            self.vllm_config.decoding_config.disable_any_whitespace
 
-        self.num_speculative_tokens = 0
-        if self.vllm_config.speculative_config is not None:
-            self.num_speculative_tokens = \
-                self.vllm_config.speculative_config.num_speculative_tokens
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
+        if isinstance(self.tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
             try:
-                if tokenizer.is_tekken:
-                    encoded_vocab = tokenizer._vocab
+                if self.tokenizer.is_tekken:
+                    encoded_vocab = self.tokenizer._vocab
                 else:
                     encoded_vocab = [
                         token for token, _ in sorted(
-                            tokenizer.get_vocab().items(),
+                            self.tokenizer.get_vocab().items(),
                             key=lambda x: x[1],
                         )
                     ]
                 stop_token_ids = None
-                if hasattr(
-                        tokenizer,
+                if (hasattr(
+                        self.tokenizer,
                         "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
+                ) and self.tokenizer.eos_token_id is not None):
+                    stop_token_ids = [self.tokenizer.eos_token_id]
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
+                    f"{type(self.tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
             tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                 encoded_vocab=encoded_vocab,
                 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                 vocab_type=xgr.VocabType.RAW
-                if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
+                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                 vocab_size=self.vocab_size,
                 stop_token_ids=stop_token_ids,
                 add_prefix_space=True,
             )
         else:
             tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
+                self.tokenizer,
                 vocab_size=self.vocab_size,
             )
         self.compiler = xgr.GrammarCompiler(

@@ -92,6 +80,11 @@ class XgrammarBackend(StructuredOutputBackend):
             cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
         )
 
+        self.num_speculative_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            self.num_speculative_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         if request_type == StructuredOutputOptions.JSON:

@@ -20,6 +20,7 @@ class StructuredOutputRequest:
     sampling_params: SamplingParams
     _grammar: Optional[Union[Future[StructuredOutputGrammar],
                              StructuredOutputGrammar]] = None
+    reasoning_ended: bool = False
 
     def _check_grammar_completion(self) -> bool:
        # NOTE: We have to lazy import to gate circular imports