mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 07:17:04 +08:00
[v0][structured output] Support reasoning output (#12955)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
This commit is contained in:
parent
bc6ccb9878
commit
bf33700ecd
@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
|
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
|
||||||
|
- It is not compatible with [`tool_calling`](#tool_calling).
|
||||||
|
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
|
||||||
|
|
||||||
## How to support a new reasoning model
|
## How to support a new reasoning model
|
||||||
|
|
||||||
@ -137,15 +143,36 @@ class ExampleParser(ReasoningParser):
|
|||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
|
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class DeepSeekReasoner(Reasoner):
|
||||||
|
"""
|
||||||
|
Reasoner for DeepSeek R series models.
|
||||||
|
"""
|
||||||
|
start_token_id: int
|
||||||
|
end_token_id: int
|
||||||
|
|
||||||
|
start_token: str = "<think>"
|
||||||
|
end_token: str = "</think>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
|
||||||
|
return cls(start_token_id=tokenizer.encode(
|
||||||
|
"<think>", add_special_tokens=False)[0],
|
||||||
|
end_token_id=tokenizer.encode("</think>",
|
||||||
|
add_special_tokens=False)[0])
|
||||||
|
|
||||||
|
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||||
|
return self.end_token_id in input_ids
|
||||||
|
```
|
||||||
|
|
||||||
|
The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
|
||||||
|
|
||||||
|
Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
vllm serve <model_tag> \
|
vllm serve <model_tag> \
|
||||||
--enable-reasoning --reasoning-parser example
|
--enable-reasoning --reasoning-parser example
|
||||||
```
|
```
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
|
|
||||||
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
|
|
||||||
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
|
|
||||||
|
|||||||
@ -0,0 +1,64 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
An example shows how to generate structured outputs from reasoning models
|
||||||
|
like DeepSeekR1. The thinking process will not be guided by the JSON
|
||||||
|
schema provided by the user. Only the final output will be structured.
|
||||||
|
|
||||||
|
To run this example, you need to start the vLLM server with the reasoning
|
||||||
|
parser:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||||
|
--enable-reasoning --reasoning-parser deepseek_r1
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates how to generate chat completions from reasoning models
|
||||||
|
using the OpenAI Python client library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai_api_key = "EMPTY"
|
||||||
|
openai_api_base = "http://localhost:8000/v1"
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
models = client.models.list()
|
||||||
|
model = models.data[0].id
|
||||||
|
|
||||||
|
|
||||||
|
# Guided decoding by JSON using Pydantic schema
|
||||||
|
class CarType(str, Enum):
|
||||||
|
sedan = "sedan"
|
||||||
|
suv = "SUV"
|
||||||
|
truck = "Truck"
|
||||||
|
coupe = "Coupe"
|
||||||
|
|
||||||
|
|
||||||
|
class CarDescription(BaseModel):
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
car_type: CarType
|
||||||
|
|
||||||
|
|
||||||
|
json_schema = CarDescription.model_json_schema()
|
||||||
|
|
||||||
|
prompt = ("Generate a JSON with the brand, model and car_type of"
|
||||||
|
"the most iconic car from the 90's, think in 100 tokens")
|
||||||
|
completion = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}],
|
||||||
|
extra_body={"guided_json": json_schema},
|
||||||
|
)
|
||||||
|
print("content", completion.choices[0].message.content)
|
||||||
|
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
|
||||||
@ -16,17 +16,33 @@ from vllm.sampling_params import GuidedDecodingParams
|
|||||||
|
|
||||||
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
||||||
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||||
|
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
|
||||||
|
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||||
|
|
||||||
|
|
||||||
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
# Initialize the tokenizer for the model here to avoid repeated loading
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def zephyr_7B_tokenzer():
|
||||||
|
return AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def deepseek_r1_qwen_tokenizer():
|
||||||
|
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
|
||||||
|
sample_json_schema):
|
||||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
regex_LP = RegexLogitsProcessor(sample_regex,
|
||||||
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
|
zephyr_7B_tokenzer,
|
||||||
|
reasoner=None)
|
||||||
json_LP = JSONLogitsProcessor(sample_json_schema,
|
json_LP = JSONLogitsProcessor(sample_json_schema,
|
||||||
tokenizer,
|
zephyr_7B_tokenzer,
|
||||||
whitespace_pattern=None)
|
whitespace_pattern=None,
|
||||||
|
reasoner=None)
|
||||||
|
|
||||||
token_ids = tokenizer.encode(
|
token_ids = zephyr_7B_tokenzer.encode(
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
@ -34,7 +50,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
|||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
token_ids = tokenizer.encode(
|
token_ids = zephyr_7B_tokenzer.encode(
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||||
)
|
)
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
@ -49,7 +65,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
|
|||||||
@pytest.mark.parametrize("is_local", [True, False])
|
@pytest.mark.parametrize("is_local", [True, False])
|
||||||
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||||
sample_regex,
|
sample_regex,
|
||||||
sample_json_schema):
|
sample_json_schema,
|
||||||
|
zephyr_7B_tokenzer):
|
||||||
|
|
||||||
config = ModelConfig(
|
config = ModelConfig(
|
||||||
MODEL_NAME,
|
MODEL_NAME,
|
||||||
@ -60,15 +77,14 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
|||||||
seed=0,
|
seed=0,
|
||||||
dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
token_ids = zephyr_7B_tokenzer.encode(
|
||||||
token_ids = tokenizer.encode(
|
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
|
|
||||||
regex_lp = get_local_guided_decoding_logits_processor(
|
regex_lp = get_local_guided_decoding_logits_processor(
|
||||||
regex_request, tokenizer, config) if is_local else \
|
regex_request, zephyr_7B_tokenzer, config) if is_local else \
|
||||||
await get_guided_decoding_logits_processor(
|
await get_guided_decoding_logits_processor(
|
||||||
regex_request, tokenizer, config)
|
regex_request, zephyr_7B_tokenzer, config)
|
||||||
assert regex_lp is not None
|
assert regex_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
@ -76,13 +92,85 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
|||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
token_ids = tokenizer.encode(
|
token_ids = zephyr_7B_tokenzer.encode(
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||||
)
|
)
|
||||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
backend=backend)
|
backend=backend)
|
||||||
json_lp = await get_guided_decoding_logits_processor(
|
json_lp = await get_guided_decoding_logits_processor(
|
||||||
json_request, tokenizer, config)
|
json_request, zephyr_7B_tokenzer, config)
|
||||||
|
assert json_lp is not None
|
||||||
|
tensor = torch.rand(32000)
|
||||||
|
original_tensor = torch.clone(tensor)
|
||||||
|
tensor = json_lp(token_ids, tensor)
|
||||||
|
assert tensor.shape == original_tensor.shape
|
||||||
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
|
||||||
|
@pytest.mark.parametrize("is_local", [True, False])
|
||||||
|
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
|
||||||
|
async def test_guided_logits_processor_with_reasoning(
|
||||||
|
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
|
||||||
|
sample_json_schema, deepseek_r1_qwen_tokenizer):
|
||||||
|
|
||||||
|
config = ModelConfig(
|
||||||
|
REASONING_MODEL_NAME,
|
||||||
|
task="generate",
|
||||||
|
tokenizer=REASONING_MODEL_NAME,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
seed=0,
|
||||||
|
dtype="bfloat16",
|
||||||
|
)
|
||||||
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
|
f"Give an example IPv4 address with this regex: {sample_regex}."
|
||||||
|
"<think>here is the thinking process")
|
||||||
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
|
|
||||||
|
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
|
||||||
|
deepseek_r1_qwen_tokenizer, config,
|
||||||
|
reasoning_backend) if is_local else \
|
||||||
|
await get_guided_decoding_logits_processor(
|
||||||
|
regex_request, deepseek_r1_qwen_tokenizer, config,
|
||||||
|
reasoning_backend)
|
||||||
|
assert regex_lp is not None
|
||||||
|
tensor = torch.rand(32000)
|
||||||
|
original_tensor = torch.clone(tensor)
|
||||||
|
tensor = regex_lp(token_ids, tensor)
|
||||||
|
assert tensor.shape == original_tensor.shape
|
||||||
|
assert torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
|
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||||
|
"<think>here is the thinking process")
|
||||||
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
|
backend=backend)
|
||||||
|
json_lp = get_local_guided_decoding_logits_processor(
|
||||||
|
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||||
|
reasoning_backend) if is_local else \
|
||||||
|
await get_guided_decoding_logits_processor(
|
||||||
|
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||||
|
assert json_lp is not None
|
||||||
|
tensor = torch.rand(32000)
|
||||||
|
original_tensor = torch.clone(tensor)
|
||||||
|
tensor = json_lp(token_ids, tensor)
|
||||||
|
assert tensor.shape == original_tensor.shape
|
||||||
|
assert torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
|
# Thinking is over, so the tensor should change.
|
||||||
|
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||||
|
f"Give an employee profile that fits this schema: {sample_json_schema}."
|
||||||
|
"<think>here is the thinking process</think> Then")
|
||||||
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
|
backend=backend)
|
||||||
|
json_lp = get_local_guided_decoding_logits_processor(
|
||||||
|
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||||
|
reasoning_backend) if is_local else \
|
||||||
|
await get_guided_decoding_logits_processor(
|
||||||
|
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
|
|||||||
@ -2715,6 +2715,8 @@ class DecodingConfig:
|
|||||||
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
|
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
|
||||||
guided_decoding_backend: str = 'xgrammar'
|
guided_decoding_backend: str = 'xgrammar'
|
||||||
|
|
||||||
|
reasoning_backend: Optional[str] = None
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
|||||||
@ -213,6 +213,8 @@ class EngineArgs:
|
|||||||
calculate_kv_scales: Optional[bool] = None
|
calculate_kv_scales: Optional[bool] = None
|
||||||
|
|
||||||
additional_config: Optional[Dict[str, Any]] = None
|
additional_config: Optional[Dict[str, Any]] = None
|
||||||
|
enable_reasoning: Optional[bool] = None
|
||||||
|
reasoning_parser: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not self.tokenizer:
|
if not self.tokenizer:
|
||||||
@ -1059,6 +1061,25 @@ class EngineArgs:
|
|||||||
"Different platforms may support different configs. Make sure the "
|
"Different platforms may support different configs. Make sure the "
|
||||||
"configs are valid for the platform you are using. The input format"
|
"configs are valid for the platform you are using. The input format"
|
||||||
" is like '{\"config_key\":\"config_value\"}'")
|
" is like '{\"config_key\":\"config_value\"}'")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-reasoning",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Whether to enable reasoning_content for the model. "
|
||||||
|
"If enabled, the model will be able to generate reasoning content."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--reasoning-parser",
|
||||||
|
type=str,
|
||||||
|
choices=["deepseek_r1"],
|
||||||
|
default=None,
|
||||||
|
help=
|
||||||
|
"Select the reasoning parser depending on the model that you're "
|
||||||
|
"using. This is used to parse the reasoning content into OpenAI "
|
||||||
|
"API format. Required for ``--enable-reasoning``.")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1332,7 +1353,10 @@ class EngineArgs:
|
|||||||
if self.enable_prompt_adapter else None
|
if self.enable_prompt_adapter else None
|
||||||
|
|
||||||
decoding_config = DecodingConfig(
|
decoding_config = DecodingConfig(
|
||||||
guided_decoding_backend=self.guided_decoding_backend)
|
guided_decoding_backend=self.guided_decoding_backend,
|
||||||
|
reasoning_backend=self.reasoning_parser
|
||||||
|
if self.enable_reasoning else None,
|
||||||
|
)
|
||||||
|
|
||||||
show_hidden_metrics = False
|
show_hidden_metrics = False
|
||||||
if self.show_hidden_metrics_for_version is not None:
|
if self.show_hidden_metrics_for_version is not None:
|
||||||
|
|||||||
@ -509,6 +509,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
tokenizer=await self.get_tokenizer_async(lora_request),
|
tokenizer=await self.get_tokenizer_async(lora_request),
|
||||||
default_guided_backend=self.decoding_config.
|
default_guided_backend=self.decoding_config.
|
||||||
guided_decoding_backend,
|
guided_decoding_backend,
|
||||||
|
reasoning_backend=self.decoding_config.reasoning_backend,
|
||||||
model_config=self.model_config)
|
model_config=self.model_config)
|
||||||
|
|
||||||
self._add_processed_request(
|
self._add_processed_request(
|
||||||
@ -530,7 +531,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
|
|
||||||
async def build_guided_decoding_logits_processor_async(
|
async def build_guided_decoding_logits_processor_async(
|
||||||
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
|
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
|
||||||
default_guided_backend: str,
|
default_guided_backend: str, reasoning_backend: Optional[str],
|
||||||
model_config: ModelConfig) -> SamplingParams:
|
model_config: ModelConfig) -> SamplingParams:
|
||||||
"""Constructs logits processors based on the guided_decoding,
|
"""Constructs logits processors based on the guided_decoding,
|
||||||
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
||||||
@ -545,14 +546,18 @@ async def build_guided_decoding_logits_processor_async(
|
|||||||
sampling_params = copy.copy(sampling_params)
|
sampling_params = copy.copy(sampling_params)
|
||||||
guided_decoding = sampling_params.guided_decoding
|
guided_decoding = sampling_params.guided_decoding
|
||||||
|
|
||||||
logger.debug("Building guided decoding logits processor. "
|
logger.info(
|
||||||
"Params: %s", guided_decoding)
|
"Building guided decoding logits processor. "
|
||||||
|
"guided_decoding: %s%s", guided_decoding,
|
||||||
|
f", reasoning_backend: {reasoning_backend}"
|
||||||
|
if reasoning_backend is not None else "")
|
||||||
|
|
||||||
guided_decoding.backend = guided_decoding.backend or default_guided_backend
|
guided_decoding.backend = guided_decoding.backend or default_guided_backend
|
||||||
|
|
||||||
processor = await get_guided_decoding_logits_processor(
|
processor = await get_guided_decoding_logits_processor(
|
||||||
guided_params=guided_decoding,
|
guided_params=guided_decoding,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
reasoning_backend=reasoning_backend,
|
||||||
model_config=model_config)
|
model_config=model_config)
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
|
|||||||
@ -2048,10 +2048,15 @@ class LLMEngine:
|
|||||||
guided_decoding.backend = guided_decoding.backend or \
|
guided_decoding.backend = guided_decoding.backend or \
|
||||||
self.decoding_config.guided_decoding_backend
|
self.decoding_config.guided_decoding_backend
|
||||||
|
|
||||||
|
logger.debug("Reasoning backend: %s",
|
||||||
|
self.decoding_config.reasoning_backend)
|
||||||
|
|
||||||
processor = get_local_guided_decoding_logits_processor(
|
processor = get_local_guided_decoding_logits_processor(
|
||||||
guided_params=guided_decoding,
|
guided_params=guided_decoding,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
model_config=self.model_config)
|
model_config=self.model_config,
|
||||||
|
reasoning_backend=self.decoding_config.reasoning_backend,
|
||||||
|
)
|
||||||
if processor:
|
if processor:
|
||||||
logits_processors.append(processor)
|
logits_processors.append(processor)
|
||||||
|
|
||||||
|
|||||||
@ -611,7 +611,8 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
default_guided_backend=(self.decoding_config.guided_decoding_backend
|
default_guided_backend=(self.decoding_config.guided_decoding_backend
|
||||||
if self.decoding_config
|
if self.decoding_config
|
||||||
else DecodingConfig.guided_decoding_backend),
|
else DecodingConfig.guided_decoding_backend),
|
||||||
model_config=self.model_config
|
model_config=self.model_config,
|
||||||
|
reasoning_backend=self.decoding_config.reasoning_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1) Create output queue for this requests.
|
# 1) Create output queue for this requests.
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from typing import List, Optional, Sequence, Union, get_args
|
|||||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||||
validate_chat_template)
|
validate_chat_template)
|
||||||
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
|
|
||||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||||
PromptAdapterPath)
|
PromptAdapterPath)
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
@ -215,23 +214,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable auto tool choice for supported models. Use "
|
help="Enable auto tool choice for supported models. Use "
|
||||||
"``--tool-call-parser`` to specify which parser to use.")
|
"``--tool-call-parser`` to specify which parser to use.")
|
||||||
parser.add_argument(
|
|
||||||
"--enable-reasoning",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Whether to enable reasoning_content for the model. "
|
|
||||||
"If enabled, the model will be able to generate reasoning content.")
|
|
||||||
|
|
||||||
valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
|
|
||||||
parser.add_argument(
|
|
||||||
"--reasoning-parser",
|
|
||||||
type=str,
|
|
||||||
metavar="{" + ",".join(valid_reasoning_parsers) + "}",
|
|
||||||
default=None,
|
|
||||||
help=
|
|
||||||
"Select the reasoning parser depending on the model that you're using."
|
|
||||||
" This is used to parse the reasoning content into OpenAI API "
|
|
||||||
"format. Required for ``--enable-reasoning``.")
|
|
||||||
|
|
||||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
|
||||||
from vllm.model_executor.guided_decoding.utils import (
|
from vllm.model_executor.guided_decoding.utils import (
|
||||||
convert_lark_to_gbnf, grammar_is_likely_lark,
|
convert_lark_to_gbnf, grammar_is_likely_lark,
|
||||||
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
||||||
@ -103,8 +104,13 @@ def maybe_backend_fallback(
|
|||||||
|
|
||||||
|
|
||||||
async def get_guided_decoding_logits_processor(
|
async def get_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
|
guided_params: GuidedDecodingParams,
|
||||||
model_config: ModelConfig) -> LogitsProcessor | None:
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
reasoning_backend: str | None = None) -> LogitsProcessor | None:
|
||||||
|
|
||||||
|
reasoner = get_reasoner(tokenizer, reasoning_backend)
|
||||||
|
|
||||||
guided_params = maybe_backend_fallback(guided_params)
|
guided_params = maybe_backend_fallback(guided_params)
|
||||||
# CFG grammar not supported by LMFE, so we use outlines instead
|
# CFG grammar not supported by LMFE, so we use outlines instead
|
||||||
if guided_params.backend_name == 'outlines':
|
if guided_params.backend_name == 'outlines':
|
||||||
@ -112,8 +118,8 @@ async def get_guided_decoding_logits_processor(
|
|||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
get_outlines_guided_decoding_logits_processor)
|
get_outlines_guided_decoding_logits_processor)
|
||||||
return await get_outlines_guided_decoding_logits_processor(
|
return await get_outlines_guided_decoding_logits_processor(
|
||||||
guided_params, tokenizer)
|
guided_params, tokenizer, reasoner)
|
||||||
if guided_params.backend_name == 'lm-format-enforcer':
|
if guided_params.backend == 'lm-format-enforcer':
|
||||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
||||||
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
||||||
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
||||||
@ -122,7 +128,7 @@ async def get_guided_decoding_logits_processor(
|
|||||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
||||||
get_local_xgrammar_guided_decoding_logits_processor)
|
get_local_xgrammar_guided_decoding_logits_processor)
|
||||||
return get_local_xgrammar_guided_decoding_logits_processor(
|
return get_local_xgrammar_guided_decoding_logits_processor(
|
||||||
guided_params, tokenizer, model_config)
|
guided_params, tokenizer, model_config, reasoner)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
||||||
@ -130,16 +136,22 @@ async def get_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
|
|
||||||
def get_local_guided_decoding_logits_processor(
|
def get_local_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
|
guided_params: GuidedDecodingParams,
|
||||||
model_config: ModelConfig) -> LogitsProcessor | None:
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
reasoning_backend: str | None = None) -> LogitsProcessor | None:
|
||||||
guided_params = maybe_backend_fallback(guided_params)
|
guided_params = maybe_backend_fallback(guided_params)
|
||||||
|
|
||||||
|
# Get the reasoner if needed, it will be None if reasoning_
|
||||||
|
reasoner = get_reasoner(tokenizer, reasoning_backend)
|
||||||
|
|
||||||
# CFG grammar not supported by LMFE, so we use outlines instead
|
# CFG grammar not supported by LMFE, so we use outlines instead
|
||||||
if guided_params.backend_name == 'outlines':
|
if guided_params.backend_name == 'outlines':
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
get_local_outlines_guided_decoding_logits_processor)
|
get_local_outlines_guided_decoding_logits_processor)
|
||||||
return get_local_outlines_guided_decoding_logits_processor(
|
return get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_params, tokenizer)
|
guided_params, tokenizer, reasoner)
|
||||||
if guided_params.backend_name == 'lm-format-enforcer':
|
if guided_params.backend_name == 'lm-format-enforcer':
|
||||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
||||||
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
||||||
@ -149,7 +161,7 @@ def get_local_guided_decoding_logits_processor(
|
|||||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
|
||||||
get_local_xgrammar_guided_decoding_logits_processor)
|
get_local_xgrammar_guided_decoding_logits_processor)
|
||||||
return get_local_xgrammar_guided_decoding_logits_processor(
|
return get_local_xgrammar_guided_decoding_logits_processor(
|
||||||
guided_params, tokenizer, model_config)
|
guided_params, tokenizer, model_config, reasoner)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
||||||
|
|||||||
@ -6,12 +6,13 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from json import dumps as json_dumps
|
from json import dumps as json_dumps
|
||||||
from re import escape as regex_escape
|
from re import escape as regex_escape
|
||||||
from typing import Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner import Reasoner
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +59,9 @@ _MAX_THREADPOOL_WORKERS = 16
|
|||||||
|
|
||||||
|
|
||||||
async def get_outlines_guided_decoding_logits_processor(
|
async def get_outlines_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
|
guided_params: GuidedDecodingParams,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
reasoner: Optional[Reasoner],
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||||
None]:
|
None]:
|
||||||
"""
|
"""
|
||||||
@ -82,11 +85,14 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
return await loop.run_in_executor(global_thread_pool,
|
return await loop.run_in_executor(global_thread_pool,
|
||||||
_get_logits_processor, guide, tokenizer,
|
_get_logits_processor, guide, tokenizer,
|
||||||
mode, guided_params.whitespace_pattern)
|
mode, guided_params.whitespace_pattern,
|
||||||
|
reasoner)
|
||||||
|
|
||||||
|
|
||||||
def get_local_outlines_guided_decoding_logits_processor(
|
def get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
|
guided_params: GuidedDecodingParams,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
reasoner: Optional[Reasoner],
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||||
None]:
|
None]:
|
||||||
"""
|
"""
|
||||||
@ -100,7 +106,7 @@ def get_local_outlines_guided_decoding_logits_processor(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return _get_logits_processor(guide, tokenizer, mode,
|
return _get_logits_processor(guide, tokenizer, mode,
|
||||||
guided_params.whitespace_pattern)
|
guided_params.whitespace_pattern, reasoner)
|
||||||
|
|
||||||
|
|
||||||
def _get_guide_and_mode(
|
def _get_guide_and_mode(
|
||||||
@ -131,14 +137,18 @@ def _get_guide_and_mode(
|
|||||||
|
|
||||||
|
|
||||||
def _get_logits_processor(
|
def _get_logits_processor(
|
||||||
guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
|
guide: str,
|
||||||
whitespace_pattern: Union[str, None]
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
mode: GuidedDecodingMode,
|
||||||
|
whitespace_pattern: Union[str, None],
|
||||||
|
reasoner: Optional[Reasoner],
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
|
||||||
if mode == GuidedDecodingMode.JSON:
|
if mode == GuidedDecodingMode.JSON:
|
||||||
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
|
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
|
||||||
|
reasoner)
|
||||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||||
return RegexLogitsProcessor(guide, tokenizer)
|
return RegexLogitsProcessor(guide, tokenizer, reasoner)
|
||||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
elif mode == GuidedDecodingMode.GRAMMAR:
|
||||||
return CFGLogitsProcessor(guide, tokenizer)
|
return CFGLogitsProcessor(guide, tokenizer, reasoner)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, DefaultDict, Dict, List, Union
|
from typing import Callable, DefaultDict, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -32,13 +32,18 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner import Reasoner
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseLogitsProcessor:
|
class BaseLogitsProcessor:
|
||||||
|
|
||||||
def __init__(self, guide: Guide):
|
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
|
||||||
self._guide: Guide = guide
|
self._guide: Guide = guide
|
||||||
|
self._reasoner = reasoner
|
||||||
# CFGState is used for the FSM state for CFGGuide
|
# CFGState is used for the FSM state for CFGGuide
|
||||||
self._fsm_state: DefaultDict[int, Union[int,
|
self._fsm_state: DefaultDict[int, Union[int,
|
||||||
CFGState]] = defaultdict(int)
|
CFGState]] = defaultdict(int)
|
||||||
@ -46,6 +51,14 @@ class BaseLogitsProcessor:
|
|||||||
def __call__(self, input_ids: List[int],
|
def __call__(self, input_ids: List[int],
|
||||||
scores: torch.Tensor) -> torch.Tensor:
|
scores: torch.Tensor) -> torch.Tensor:
|
||||||
"""Use the FSM to bias the logits before sampling the next token."""
|
"""Use the FSM to bias the logits before sampling the next token."""
|
||||||
|
|
||||||
|
# Skip the structured logits processing if reasoning is not finished.
|
||||||
|
# reasoner is not None only when `--enable-reasoning` is set.
|
||||||
|
if self._reasoner is not None and \
|
||||||
|
not self._reasoner.is_reasoning_end(
|
||||||
|
input_ids):
|
||||||
|
return scores
|
||||||
|
|
||||||
seq_id = hash(tuple(input_ids))
|
seq_id = hash(tuple(input_ids))
|
||||||
|
|
||||||
if len(input_ids) > 0:
|
if len(input_ids) > 0:
|
||||||
@ -113,7 +126,12 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
|
|||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
tokenizer = _adapt_tokenizer(tokenizer)
|
||||||
return RegexGuide.from_regex(regex_string, tokenizer)
|
return RegexGuide.from_regex(regex_string, tokenizer)
|
||||||
|
|
||||||
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
|
def __init__(
|
||||||
|
self,
|
||||||
|
regex_string: str,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
reasoner: Optional[Reasoner],
|
||||||
|
):
|
||||||
"""Compile the FSM that drives the regex-structured generation.
|
"""Compile the FSM that drives the regex-structured generation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -125,14 +143,15 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
RegexLogitsProcessor._get_guide(regex_string, tokenizer))
|
RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner)
|
||||||
|
|
||||||
|
|
||||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||||
|
|
||||||
def __init__(self, schema: Union[str, Dict, BaseModel],
|
def __init__(self, schema: Union[str, Dict, BaseModel],
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
whitespace_pattern: Union[str, None]):
|
whitespace_pattern: Union[str, None],
|
||||||
|
reasoner: Optional[Reasoner]):
|
||||||
"""Compile the FSM that drives the JSON-guided generation.
|
"""Compile the FSM that drives the JSON-guided generation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -160,7 +179,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
|||||||
f"a Pydantic object, a dictionary or a string that contains "
|
f"a Pydantic object, a dictionary or a string that contains "
|
||||||
f"the JSON Schema specification")
|
f"the JSON Schema specification")
|
||||||
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
||||||
super().__init__(regex_string, tokenizer)
|
super().__init__(regex_string, tokenizer, reasoner)
|
||||||
|
|
||||||
|
|
||||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||||
@ -171,7 +190,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
|
|||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
tokenizer = _adapt_tokenizer(tokenizer)
|
||||||
return CFGGuide(cfg, tokenizer)
|
return CFGGuide(cfg, tokenizer)
|
||||||
|
|
||||||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
|
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
|
||||||
|
reasoner: Optional[Reasoner]):
|
||||||
"""Compile the FSM that drives the context free grammar generation.
|
"""Compile the FSM that drives the context free grammar generation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -182,7 +202,8 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
|
|||||||
The model's tokenizer
|
The model's tokenizer
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
|
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer),
|
||||||
|
reasoner)
|
||||||
self._guide = self._guide.copy()
|
self._guide = self._guide.copy()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
23
vllm/model_executor/guided_decoding/reasoner/__init__.py
Normal file
23
vllm/model_executor/guided_decoding/reasoner/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
|
||||||
|
DeepSeekReasoner)
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
|
||||||
|
|
||||||
|
|
||||||
|
def get_reasoner(tokenizer: PreTrainedTokenizer,
|
||||||
|
reasoning_backend: str | None) -> Reasoner | None:
|
||||||
|
if reasoning_backend is None:
|
||||||
|
# No reasoning backend specified
|
||||||
|
return None
|
||||||
|
elif reasoning_backend == "deepseek_r1":
|
||||||
|
return DeepSeekReasoner.from_tokenizer(tokenizer)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Reasoner", "get_reasoner"]
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepSeekReasoner(Reasoner):
|
||||||
|
"""
|
||||||
|
Reasoner for DeepSeek R series models.
|
||||||
|
"""
|
||||||
|
start_token_id: int
|
||||||
|
end_token_id: int
|
||||||
|
|
||||||
|
start_token: str = "<think>"
|
||||||
|
end_token: str = "</think>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
|
||||||
|
return cls(start_token_id=tokenizer.encode(
|
||||||
|
"<think>", add_special_tokens=False)[0],
|
||||||
|
end_token_id=tokenizer.encode("</think>",
|
||||||
|
add_special_tokens=False)[0])
|
||||||
|
|
||||||
|
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||||
|
return self.end_token_id in input_ids
|
||||||
19
vllm/model_executor/guided_decoding/reasoner/reasoner.py
Normal file
19
vllm/model_executor/guided_decoding/reasoner/reasoner.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Reasoner(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||||
|
pass
|
||||||
@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, List
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
from xgrammar.base import _core as xgr_core
|
from xgrammar.base import _core as xgr_core
|
||||||
@ -19,7 +21,6 @@ except ImportError:
|
|||||||
xgr_installed = False
|
xgr_installed = False
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||||
grammar_is_likely_lark)
|
grammar_is_likely_lark)
|
||||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||||
@ -28,6 +29,7 @@ if TYPE_CHECKING:
|
|||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.model_executor.guided_decoding.reasoner import Reasoner
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -38,12 +40,13 @@ def get_local_xgrammar_guided_decoding_logits_processor(
|
|||||||
guided_params: GuidedDecodingParams,
|
guided_params: GuidedDecodingParams,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
|
reasoner: Reasoner | None,
|
||||||
max_threads: int = 8):
|
max_threads: int = 8):
|
||||||
config = GrammarConfig.from_guided_params(guided_params=guided_params,
|
config = GrammarConfig.from_guided_params(guided_params=guided_params,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_threads=max_threads)
|
max_threads=max_threads)
|
||||||
return XGrammarLogitsProcessor(config)
|
return XGrammarLogitsProcessor(config, reasoner)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -293,6 +296,7 @@ class GrammarConfig:
|
|||||||
class XGrammarLogitsProcessor:
|
class XGrammarLogitsProcessor:
|
||||||
"""Wrapper class to support pickle protocol"""
|
"""Wrapper class to support pickle protocol"""
|
||||||
config: GrammarConfig
|
config: GrammarConfig
|
||||||
|
reasoner: Reasoner | None = None
|
||||||
|
|
||||||
ctx: xgr.CompiledGrammar | None = None
|
ctx: xgr.CompiledGrammar | None = None
|
||||||
token_bitmask: torch.Tensor = None # type: ignore[assignment]
|
token_bitmask: torch.Tensor = None # type: ignore[assignment]
|
||||||
@ -301,10 +305,11 @@ class XGrammarLogitsProcessor:
|
|||||||
prefilled: bool = field(default=False)
|
prefilled: bool = field(default=False)
|
||||||
|
|
||||||
def __getstate__(self) -> dict[str, Any]:
|
def __getstate__(self) -> dict[str, Any]:
|
||||||
return {'config': self.config}
|
return {'config': self.config, 'reasoner': self.reasoner}
|
||||||
|
|
||||||
def __setstate__(self, state: dict[str, Any]):
|
def __setstate__(self, state: dict[str, Any]):
|
||||||
self.config = state['config']
|
self.config = state['config']
|
||||||
|
self.reasoner = state['reasoner']
|
||||||
|
|
||||||
self.ctx = None
|
self.ctx = None
|
||||||
self.matchers = []
|
self.matchers = []
|
||||||
@ -331,6 +336,14 @@ class XGrammarLogitsProcessor:
|
|||||||
|
|
||||||
def __call__(self, input_ids: list[int],
|
def __call__(self, input_ids: list[int],
|
||||||
scores: torch.Tensor) -> torch.Tensor:
|
scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Skip the structured logits processing if reasoning is not finished.
|
||||||
|
# reasoner is not None only when `--enable-reasoning` is set.
|
||||||
|
if self.reasoner is not None and \
|
||||||
|
not self.reasoner.is_reasoning_end(
|
||||||
|
input_ids):
|
||||||
|
return scores
|
||||||
|
|
||||||
if self.ctx is None:
|
if self.ctx is None:
|
||||||
self._ensure_ctx()
|
self._ensure_ctx()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user