[v0][structured output] Support reasoning output (#12955)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
This commit is contained in: parent bc6ccb9878 · commit bf33700ecd
@@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
}
```

Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You can check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
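For streaming, a minimal sketch with `requests` could look like the following; it assumes a server started with `--enable-reasoning --reasoning-parser deepseek_r1` and reads the `reasoning_content`/`content` delta fields defensively. The linked example above is the authoritative version.

```python
import json

import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": "Which is larger, 9.9 or 9.11?"}],
        "stream": True,
    },
    stream=True,
)

for raw_line in response.iter_lines():
    # vLLM streams OpenAI-style server-sent events: `data: {...}` lines.
    if not raw_line or not raw_line.startswith(b"data: "):
        continue
    payload = raw_line.decode("utf-8")[len("data: "):]
    if payload.strip() == "[DONE]":
        break
    choices = json.loads(payload).get("choices", [])
    if not choices:
        continue
    delta = choices[0].get("delta", {})
    # `reasoning_content` carries the thinking tokens; `content` the final answer.
    print(delta.get("reasoning_content") or delta.get("content") or "",
          end="", flush=True)
```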

## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with [`tool_calling`](#tool_calling).
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.

## How to support a new reasoning model
@@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
        """
```

After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when launching the vLLM server.
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.

```python
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.
    """
    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        return cls(start_token_id=tokenizer.encode(
            "<think>", add_special_tokens=False)[0],
                   end_token_id=tokenizer.encode("</think>",
                                                 add_special_tokens=False)[0])

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids
```

A structured output engine such as xgrammar uses `end_token_id` to check whether the model is still generating reasoning content, and skips structured output enforcement while that is the case.
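As a minimal sketch (the committed implementations live in the outlines and xgrammar logits processors later in this diff), the skip amounts to a guard at the top of the logits-processor call, assuming a `reasoner` with the `is_reasoning_end` interface shown above:

```python
import torch


def masked_scores(reasoner, apply_grammar_mask, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
    # While the model is still thinking (no </think> seen yet), leave the
    # logits untouched so the reasoning stays unconstrained.
    if reasoner is not None and not reasoner.is_reasoning_end(input_ids):
        return scores
    # Reasoning has ended: apply the grammar/schema mask to the final answer.
    return apply_grammar_mask(input_ids, scores)
```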

Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.

```bash
vllm serve <model_tag> \
    --enable-reasoning --reasoning-parser example
```

## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
An example that shows how to generate structured outputs from reasoning models
like DeepSeek-R1. The thinking process will not be guided by the JSON
schema provided by the user. Only the final output will be structured.

To run this example, you need to start the vLLM server with the reasoning
parser:

```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```

This example demonstrates how to generate chat completions from reasoning models
using the OpenAI Python client library.
"""

from enum import Enum

from openai import OpenAI
from pydantic import BaseModel

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id


# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


json_schema = CarDescription.model_json_schema()

prompt = ("Generate a JSON with the brand, model and car_type of "
          "the most iconic car from the 90's, think in 100 tokens")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("content", completion.choices[0].message.content)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
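One optional follow-up to the example above (not part of the committed script): because the final output is guided by the `CarDescription` schema, it can be parsed back through the same Pydantic model to verify it.

```python
# Round-trip the guided JSON output through the Pydantic model that
# produced the schema (Pydantic v2 API).
car = CarDescription.model_validate_json(completion.choices[0].message.content)
print(car.brand, car.model, car.car_type)
```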
@@ -16,17 +16,33 @@ from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


def test_guided_logits_processors(sample_regex, sample_json_schema):
# Initialize the tokenizer for the model here to avoid repeated loading
@pytest.fixture(scope="module")
def zephyr_7B_tokenzer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
                                  sample_json_schema):
    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
    regex_LP = RegexLogitsProcessor(sample_regex,
                                    zephyr_7B_tokenzer,
                                    reasoner=None)
    json_LP = JSONLogitsProcessor(sample_json_schema,
                                  tokenizer,
                                  whitespace_pattern=None)
                                  zephyr_7B_tokenzer,
                                  whitespace_pattern=None,
                                  reasoner=None)

    token_ids = tokenizer.encode(
    token_ids = zephyr_7B_tokenzer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}")
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
@@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    token_ids = tokenizer.encode(
    token_ids = zephyr_7B_tokenzer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )
    tensor = torch.rand(32000)
@@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
                                                 sample_regex,
                                                 sample_json_schema):
                                                 sample_json_schema,
                                                 zephyr_7B_tokenzer):

    config = ModelConfig(
        MODEL_NAME,
@@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
        seed=0,
        dtype="bfloat16",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    token_ids = tokenizer.encode(
    token_ids = zephyr_7B_tokenzer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}")
    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

    regex_lp = get_local_guided_decoding_logits_processor(
        regex_request, tokenizer, config) if is_local else \
        regex_request, zephyr_7B_tokenzer, config) if is_local else \
        await get_guided_decoding_logits_processor(
            regex_request, tokenizer, config)
            regex_request, zephyr_7B_tokenzer, config)
    assert regex_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
@@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    token_ids = tokenizer.encode(
    token_ids = zephyr_7B_tokenzer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = await get_guided_decoding_logits_processor(
        json_request, tokenizer, config)
        json_request, zephyr_7B_tokenzer, config)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)


@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
                         GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
        backend: str, is_local: bool, reasoning_backend: str, sample_regex,
        sample_json_schema, deepseek_r1_qwen_tokenizer):

    config = ModelConfig(
        REASONING_MODEL_NAME,
        task="generate",
        tokenizer=REASONING_MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}."
        "<think>here is the thinking process")
    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

    regex_lp = get_local_guided_decoding_logits_processor(regex_request,
        deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            regex_request, deepseek_r1_qwen_tokenizer, config,
            reasoning_backend)
    assert regex_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = regex_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert torch.allclose(tensor, original_tensor)

    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}."
        "<think>here is the thinking process")
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = get_local_guided_decoding_logits_processor(
        json_request, deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert torch.allclose(tensor, original_tensor)

    # Thinking is over, so the tensor should change.
    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}."
        "<think>here is the thinking process</think> Then")
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = get_local_guided_decoding_logits_processor(
        json_request, deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
@@ -2715,6 +2715,8 @@ class DecodingConfig:
    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
    guided_decoding_backend: str = 'xgrammar'

    reasoning_backend: Optional[str] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,

@@ -213,6 +213,8 @@ class EngineArgs:
    calculate_kv_scales: Optional[bool] = None

    additional_config: Optional[Dict[str, Any]] = None
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None

    def __post_init__(self):
        if not self.tokenizer:
@@ -1059,6 +1061,25 @@ class EngineArgs:
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=["deepseek_r1"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

        return parser

    @classmethod
@@ -1332,7 +1353,10 @@ class EngineArgs:
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )

        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:

@@ -509,6 +509,7 @@ class _AsyncLLMEngine(LLMEngine):
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend,
                reasoning_backend=self.decoding_config.reasoning_backend,
                model_config=self.model_config)

        self._add_processed_request(
@@ -530,7 +531,7 @@ class _AsyncLLMEngine(LLMEngine):

async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Constructs logits processors based on the guided_decoding,
    logits_bias, and allowed_token_ids fields in sampling_params. Deletes
@@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
    sampling_params = copy.copy(sampling_params)
    guided_decoding = sampling_params.guided_decoding

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)
    logger.info(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:

@@ -2048,10 +2048,15 @@ class LLMEngine:
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            logger.debug("Reasoning backend: %s",
                         self.decoding_config.reasoning_backend)

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding,
                tokenizer=tokenizer,
                model_config=self.model_config)
                model_config=self.model_config,
                reasoning_backend=self.decoding_config.reasoning_backend,
            )
            if processor:
                logits_processors.append(processor)


@@ -611,7 +611,8 @@ class MQLLMEngineClient(EngineClient):
            default_guided_backend=(self.decoding_config.guided_decoding_backend
                                    if self.decoding_config
                                    else DecodingConfig.guided_decoding_backend),
            model_config=self.model_config
            model_config=self.model_config,
            reasoning_backend=self.decoding_config.reasoning_backend,
        )

        # 1) Create output queue for this requests.

@@ -13,7 +13,6 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         validate_chat_template)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                    PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        default=False,
        help="Enable auto tool choice for supported models. Use "
        "``--tool-call-parser`` to specify which parser to use.")
    parser.add_argument(
        "--enable-reasoning",
        action="store_true",
        default=False,
        help="Whether to enable reasoning_content for the model. "
        "If enabled, the model will be able to generate reasoning content.")

    valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
    parser.add_argument(
        "--reasoning-parser",
        type=str,
        metavar="{" + ",".join(valid_reasoning_parsers) + "}",
        default=None,
        help=
        "Select the reasoning parser depending on the model that you're using."
        " This is used to parse the reasoning content into OpenAI API "
        "format. Required for ``--enable-reasoning``.")

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    parser.add_argument(
@@ -5,6 +5,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
    convert_lark_to_gbnf, grammar_is_likely_lark,
    has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
@@ -103,8 +104,13 @@ def maybe_backend_fallback(


async def get_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig) -> LogitsProcessor | None:
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoning_backend: str | None = None) -> LogitsProcessor | None:

    reasoner = get_reasoner(tokenizer, reasoning_backend)

    guided_params = maybe_backend_fallback(guided_params)
    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
@@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor(
        from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend_name == 'lm-format-enforcer':
            guided_params, tokenizer, reasoner)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
@@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor(
        from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config)
            guided_params, tokenizer, model_config, reasoner)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor(


def get_local_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig) -> LogitsProcessor | None:
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoning_backend: str | None = None) -> LogitsProcessor | None:
    guided_params = maybe_backend_fallback(guided_params)

    # Get the reasoner if needed, it will be None if reasoning_backend is None.
    reasoner = get_reasoner(tokenizer, reasoning_backend)

    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
            get_local_outlines_guided_decoding_logits_processor)
        return get_local_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
            guided_params, tokenizer, reasoner)
    if guided_params.backend_name == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
@@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
        from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config)
            guided_params, tokenizer, model_config, reasoner)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -6,12 +6,13 @@ import os
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from typing import Optional, Tuple, Union

from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.sampling_params import GuidedDecodingParams


@@ -58,7 +59,9 @@ _MAX_THREADPOOL_WORKERS = 16


async def get_outlines_guided_decoding_logits_processor(
    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
    guided_params: GuidedDecodingParams,
    tokenizer: PreTrainedTokenizerBase,
    reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
@@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor(

    return await loop.run_in_executor(global_thread_pool,
                                      _get_logits_processor, guide, tokenizer,
                                      mode, guided_params.whitespace_pattern)
                                      mode, guided_params.whitespace_pattern,
                                      reasoner)


def get_local_outlines_guided_decoding_logits_processor(
    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
    guided_params: GuidedDecodingParams,
    tokenizer: PreTrainedTokenizerBase,
    reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
@@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor(
        return None

    return _get_logits_processor(guide, tokenizer, mode,
                                 guided_params.whitespace_pattern)
                                 guided_params.whitespace_pattern, reasoner)


def _get_guide_and_mode(
@@ -131,14 +137,18 @@ def _get_guide_and_mode(


def _get_logits_processor(
    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
    whitespace_pattern: Union[str, None]
    guide: str,
    tokenizer: PreTrainedTokenizerBase,
    mode: GuidedDecodingMode,
    whitespace_pattern: Union[str, None],
    reasoner: Optional[Reasoner],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
                                   reasoner)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
        return RegexLogitsProcessor(guide, tokenizer, reasoner)
    elif mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
        return CFGLogitsProcessor(guide, tokenizer, reasoner)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
@@ -19,7 +19,7 @@ import copy
import json
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Union

import numpy as np
import torch
@@ -32,13 +32,18 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform

logger = init_logger(__name__)


class BaseLogitsProcessor:

    def __init__(self, guide: Guide):
    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
        self._guide: Guide = guide
        self._reasoner = reasoner
        # CFGState is used for the FSM state for CFGGuide
        self._fsm_state: DefaultDict[int, Union[int,
                                                CFGState]] = defaultdict(int)
@@ -46,6 +51,14 @@ class BaseLogitsProcessor:
    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""

        # Skip the structured logits processing if reasoning is not finished.
        # reasoner is not None only when `--enable-reasoning` is set.
        if self._reasoner is not None and \
            not self._reasoner.is_reasoning_end(
                input_ids):
            return scores

        seq_id = hash(tuple(input_ids))

        if len(input_ids) > 0:
@@ -113,7 +126,12 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
        tokenizer = _adapt_tokenizer(tokenizer)
        return RegexGuide.from_regex(regex_string, tokenizer)

    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
    def __init__(
        self,
        regex_string: str,
        tokenizer: PreTrainedTokenizerBase,
        reasoner: Optional[Reasoner],
    ):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
@@ -125,14 +143,15 @@ class RegexLogitsProcessor(BaseLogitsProcessor):

        """
        super().__init__(
            RegexLogitsProcessor._get_guide(regex_string, tokenizer))
            RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)


class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self, schema: Union[str, Dict, BaseModel],
                 tokenizer: PreTrainedTokenizerBase,
                 whitespace_pattern: Union[str, None]):
                 whitespace_pattern: Union[str, None],
                 reasoner: Optional[Reasoner]):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
@@ -160,7 +179,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
            f"a Pydantic object, a dictionary or a string that contains "
            f"the JSON Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
        super().__init__(regex_string, tokenizer, reasoner)


class CFGLogitsProcessor(BaseLogitsProcessor):
@@ -171,7 +190,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
        tokenizer = _adapt_tokenizer(tokenizer)
        return CFGGuide(cfg, tokenizer)

    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
                 reasoner: Optional[Reasoner]):
        """Compile the FSM that drives the context free grammar generation.

        Parameters
@@ -182,7 +202,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
        The model's tokenizer

        """
        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
                         reasoner)
        self._guide = self._guide.copy()
vllm/model_executor/guided_decoding/reasoner/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from transformers import PreTrainedTokenizer

from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
    DeepSeekReasoner)
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner


def get_reasoner(tokenizer: PreTrainedTokenizer,
                 reasoning_backend: str | None) -> Reasoner | None:
    if reasoning_backend is None:
        # No reasoning backend specified
        return None
    elif reasoning_backend == "deepseek_r1":
        return DeepSeekReasoner.from_tokenizer(tokenizer)
    else:
        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")


__all__ = ["Reasoner", "get_reasoner"]
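A quick usage sketch for `get_reasoner` (illustrative only; it assumes the DeepSeek-R1-Distill tokenizer can be downloaded locally):

```python
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.reasoner import get_reasoner

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
reasoner = get_reasoner(tokenizer, "deepseek_r1")

# Still thinking: no </think> token among the generated ids yet.
thinking_ids = tokenizer.encode("<think>let me reason about this",
                                add_special_tokens=False)
assert not reasoner.is_reasoning_end(thinking_ids)

# Reasoning finished: the </think> token id is present.
done_ids = tokenizer.encode("<think>done</think> The answer is 42",
                            add_special_tokens=False)
assert reasoner.is_reasoning_end(done_ids)

# No reasoning backend configured -> no reasoner; structured output applies
# from the first generated token.
assert get_reasoner(tokenizer, None) is None
```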
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

from transformers import PreTrainedTokenizer

from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner


@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.
    """
    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        return cls(start_token_id=tokenizer.encode(
            "<think>", add_special_tokens=False)[0],
                   end_token_id=tokenizer.encode("</think>",
                                                 add_special_tokens=False)[0])

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids
vllm/model_executor/guided_decoding/reasoner/reasoner.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass

from transformers import PreTrainedTokenizer


@dataclass
class Reasoner(ABC):

    @abstractmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        pass

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        pass
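This ABC is what a new reasoning model implements to get structured-output support. A hypothetical sketch for a model that wraps its thinking in `<reasoning>`/`</reasoning>` tags (class name and tags invented for illustration) could look like this; a matching branch in `get_reasoner`, keyed on the chosen `--reasoning-parser` name, would then return it, mirroring the `deepseek_r1` case above.

```python
# Hypothetical example of supporting another reasoning model; the tag names
# and class name are invented for illustration.
from dataclasses import dataclass

from transformers import PreTrainedTokenizer

from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner


@dataclass
class ExampleReasoner(Reasoner):
    start_token_id: int
    end_token_id: int

    start_token: str = "<reasoning>"
    end_token: str = "</reasoning>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        # Assumes the tags are single tokens in the model's vocabulary.
        return cls(
            start_token_id=tokenizer.encode("<reasoning>",
                                            add_special_tokens=False)[0],
            end_token_id=tokenizer.encode("</reasoning>",
                                          add_special_tokens=False)[0])

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        # Structured decoding kicks in once the closing tag has been emitted.
        return self.end_token_id in input_ids
```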
@@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, List
import torch
from transformers import PreTrainedTokenizerFast

from vllm.logger import init_logger

try:
    import xgrammar as xgr
    from xgrammar.base import _core as xgr_core
@@ -19,7 +21,6 @@ except ImportError:
    xgr_installed = False
    pass

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                       grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -28,6 +29,7 @@ if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from vllm.config import ModelConfig
    from vllm.model_executor.guided_decoding.reasoner import Reasoner
    from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)
@@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoner: Reasoner | None,
        max_threads: int = 8):
    config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                              model_config=model_config,
                                              tokenizer=tokenizer,
                                              max_threads=max_threads)
    return XGrammarLogitsProcessor(config)
    return XGrammarLogitsProcessor(config, reasoner)


@dataclass(frozen=True)
@@ -293,6 +296,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
    """Wrapper class to support pickle protocol"""
    config: GrammarConfig
    reasoner: Reasoner | None = None

    ctx: xgr.CompiledGrammar | None = None
    token_bitmask: torch.Tensor = None # type: ignore[assignment]
@@ -301,10 +305,11 @@ class XGrammarLogitsProcessor:
    prefilled: bool = field(default=False)

    def __getstate__(self) -> dict[str, Any]:
        return {'config': self.config}
        return {'config': self.config, 'reasoner': self.reasoner}

    def __setstate__(self, state: dict[str, Any]):
        self.config = state['config']
        self.reasoner = state['reasoner']

        self.ctx = None
        self.matchers = []
@@ -331,6 +336,14 @@ class XGrammarLogitsProcessor:

    def __call__(self, input_ids: list[int],
                 scores: torch.Tensor) -> torch.Tensor:

        # Skip the structured logits processing if reasoning is not finished.
        # reasoner is not None only when `--enable-reasoning` is set.
        if self.reasoner is not None and \
            not self.reasoner.is_reasoning_end(
                input_ids):
            return scores

        if self.ctx is None:
            self._ensure_ctx()