[Frontend] Support reasoning content for deepseek r1 (#12473)
Signed-off-by: Ce Gao <cegao@tensorchord.ai> Co-authored-by: Rafael Vasquez <rafvasq21@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in: parent fbb5bd4cef, commit a7e3eba66f
151 docs/source/features/reasoning_outputs.md Normal file
@ -0,0 +1,151 @@
(reasoning-outputs)=

# Reasoning Outputs

vLLM offers support for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions.

Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.

## Supported Models

vLLM currently supports the following reasoning models:

- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) (`deepseek_r1`, which looks for `<think> ... </think>`)

## Quickstart

To use reasoning models, specify the `--enable-reasoning` and `--reasoning-parser` flags when starting the vLLM server. The `--reasoning-parser` flag selects the reasoning parser used to extract reasoning content from the model output.

```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```

Next, make a request to the chat completion endpoint; the response will include the reasoning content.

```python
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content

print("reasoning_content:", reasoning_content)
print("content:", content)
```

The `reasoning_content` field contains the reasoning steps that led to the final conclusion, while the `content` field contains the final conclusion.

## Streaming chat completions

Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field of [chat completion response chunks](https://platform.openai.com/docs/api-reference/chat/streaming).

```json
{
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1694268190,
    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "system_fingerprint": "fp_44709d6fcb",
    "choices": [
        {
            "index": 0,
            "delta": {
                "role": "assistant",
                "reasoning_content": "is"
            },
            "logprobs": null,
            "finish_reason": null
        }
    ]
}
```

Please note that streaming `reasoning_content` is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests instead.
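
A minimal sketch of such a streaming request, assuming the server from the Quickstart above is running locally (see the full streaming example script added in this commit for a more complete version):

```python
import json

import requests

# Assumes the vLLM server from the Quickstart is running on localhost:8000.
url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "messages": [{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as response:
    for line in response.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8").removeprefix("data:").strip()
        if data == "[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        # Reasoning tokens arrive in `reasoning_content`; the final answer in `content`.
        print(delta.get("reasoning_content") or delta.get("content") or "", end="")
```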

## How to support a new reasoning model

You can add a new `ReasoningParser` similar to `vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py`.

```python
# import the required packages
from typing import Optional, Sequence, Tuple, Union

from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
    ReasoningParser, ReasoningParserManager)
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.transformers_utils.tokenizer import AnyTokenizer


# define a reasoning parser and register it to vllm
# the name list in register_module can be used
# in --reasoning-parser.
@ReasoningParserManager.register_module(["example"])
class ExampleParser(ReasoningParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming. Has to be an instance method because it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)
        """

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        Tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.
        """
```

After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when starting the server.

```bash
vllm serve <model_tag> \
    --enable-reasoning --reasoning-parser example
```
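
A parser defined outside the vLLM source tree can also be loaded programmatically: the `ReasoningParserManager.import_reasoning_parser` helper added in this commit imports a parser module from a file path, which registers any parsers that module decorates with `register_module`. A minimal sketch, assuming a hypothetical plugin file:

```python
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

# Hypothetical path to a file that defines and registers a custom parser
# via @ReasoningParserManager.register_module(["example"]).
ReasoningParserManager.import_reasoning_parser("/path/to/example_parser.py")
```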

## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
@ -90,6 +90,7 @@ models/extensions/index
features/quantization/index
features/lora
features/tool_calling
features/reasoning_outputs
features/structured_outputs
features/automatic_prefix_caching
features/disagg_prefill
@ -0,0 +1,53 @@
"""
An example showing how to generate chat completions from reasoning models
like DeepSeek R1.

To run this example, you need to start the vLLM server with the reasoning
parser:

```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```

This example demonstrates how to generate chat completions from reasoning models
using the OpenAI Python client library.
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content

print("reasoning_content:", reasoning_content)
print("content:", content)

# Round 2
messages.append({"role": "assistant", "content": content})
messages.append({
    "role": "user",
    "content": "How many Rs are there in the word 'strawberry'?",
})
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content

print("reasoning_content:", reasoning_content)
print("content:", content)
@ -0,0 +1,90 @@
"""
An example showing how to generate chat completions from reasoning models
like DeepSeek R1.

To run this example, you need to start the vLLM server with the reasoning
parser:

```bash
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --enable-reasoning --reasoning-parser deepseek_r1
```

Unlike openai_chat_completion_with_reasoning.py, this example demonstrates the
streaming chat completions feature.

The streaming chat completions feature allows you to receive chat completions
in real time as they are generated by the model. This is useful for scenarios
where you want to display chat completions to the user as they are generated
by the model.

Here we do not use the OpenAI Python client library, because it does not support
`reasoning_content` fields in the response.
"""

import json

import requests

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

models = requests.get(
    f"{openai_api_base}/models",
    headers={
        "Authorization": f"Bearer {openai_api_key}"
    },
).json()
model = models["data"][0]["id"]

# Streaming chat completions
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]

response = requests.post(
    f"{openai_api_base}/chat/completions",
    headers={"Authorization": f"Bearer {openai_api_key}"},
    json={
        "model": model,
        "messages": messages,
        "stream": True
    },
    stream=True,  # stream the HTTP response so chunks arrive as they are generated
)

print("client: Start streaming chat completions...")
printed_reasoning_content = False
printed_content = False
# Make the streaming request
if response.status_code == 200:
    # Process the streaming response
    for line in response.iter_lines():
        if line:  # Filter out keep-alive new lines
            # Decode the line and parse the JSON
            decoded_line = line.decode("utf-8")
            if decoded_line.startswith("data:"):
                data = decoded_line[5:].strip()  # Remove "data:" prefix
                if data == "[DONE]":  # End of stream
                    print("\nclient: Stream completed.")
                    break
                try:
                    # Parse the JSON data
                    chunk = json.loads(data)
                    reasoning_content = chunk["choices"][0]["delta"].get(
                        "reasoning_content", "")
                    content = chunk["choices"][0]["delta"].get("content", "")

                    if reasoning_content:
                        if not printed_reasoning_content:
                            printed_reasoning_content = True
                            print("reasoning_content:", end="", flush=True)
                        print(reasoning_content, end="", flush=True)
                    elif content:
                        if not printed_content:
                            printed_content = True
                            print("\ncontent:", end="", flush=True)
                        # Extract and print the content
                        print(content, end="", flush=True)
                except json.JSONDecodeError:
                    print("Error decoding JSON:", decoded_line)
else:
    print(f"Error: {response.status_code} - {response.text}")
@ -0,0 +1,120 @@
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.entrypoints.openai.reasoning_parsers.utils import (
    run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)

parser_name = "deepseek_r1"
start_token = "<think>"
end_token = "</think>"

SIMPLE_REASONING = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
COMPLETE_REASONING = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
NO_REASONING = {
    "output": "This is a reasoning section",
    "reasoning_content": None,
    "content": "This is a reasoning section",
}
MULTIPLE_LINES = {
    "output": "<think>This\nThat</think>This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
}
SHORTEST_REASONING_NO_STREAMING = {
    "output": "<think></think>This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
}
SHORTEST_REASONING = {
    "output": "<think></think>This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
}

TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_streaming",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_streaming",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_streaming",
    ),
    pytest.param(
        False,
        NO_REASONING,
        id="no_streaming",
    ),
    pytest.param(
        True,
        NO_REASONING,
        id="no_streaming",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING,
        id="shortest_streaming",
    ),
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING,
        id="shortest_streaming",
    ),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
):
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.add_tokens([start_token, end_token])
    output = tokenizer.tokenize(param_dict["output"])
    # decode everything to tokens
    output_tokens: List[str] = [
        tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]
93 tests/entrypoints/openai/reasoning_parsers/utils.py Normal file
@ -0,0 +1,93 @@
from typing import List, Optional, Tuple, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser


class StreamingReasoningReconstructor:

    def __init__(self):
        self.reasoning_content = None
        self.other_content = None

    def append_delta(self, delta: DeltaMessage):
        # content and the reasoning content should not be present
        # at the same time
        assert delta.content is None or delta.reasoning_content is None, (
            "Both content and reasoning content are present in the "
            "delta message")
        if delta.content is not None:
            if self.other_content is None:
                self.other_content = delta.content
            else:
                self.other_content += delta.content
        else:
            if self.reasoning_content is None:
                self.reasoning_content = delta.reasoning_content
            else:
                self.reasoning_content += delta.reasoning_content


def run_reasoning_extraction(
    reasoning_parser: ReasoningParser,
    model_output: List[str],
    request: Union[ChatCompletionRequest, None] = None,
    streaming: bool = False,
) -> Tuple[Optional[str], Optional[str]]:
    if streaming:
        reconstructor = run_reasoning_extraction_streaming(
            reasoning_parser,
            model_output,
            request,
        )
        return (
            reconstructor.reasoning_content,
            reconstructor.other_content or None,
        )
    else:
        reasoning, content = run_reasoning_extraction_nonstreaming(
            reasoning_parser, model_output, request)
        return reasoning, content


def run_reasoning_extraction_nonstreaming(
    reasoning_parser: ReasoningParser,
    model_output: List[str],
    request: Union[ChatCompletionRequest, None] = None,
) -> Tuple[Optional[str], Optional[str]]:
    request = request or ChatCompletionRequest(messages=[], model="test-model")
    return reasoning_parser.extract_reasoning_content(
        model_output=''.join(model_output), request=request)


def run_reasoning_extraction_streaming(
    reasoning_parser: ReasoningParser,
    model_deltas: List[str],
    request: Union[ChatCompletionRequest, None] = None,
) -> StreamingReasoningReconstructor:
    request = request or ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingReasoningReconstructor()
    previous_text = ""
    previous_tokens: List[int] = []
    for delta in model_deltas:
        token_delta = [
            reasoning_parser.vocab.get(token)
            for token in reasoning_parser.model_tokenizer.tokenize(delta)
            if token in reasoning_parser.vocab
        ]
        current_text = previous_text + delta
        current_tokens = previous_tokens + token_delta
        delta_message = reasoning_parser.extract_reasoning_content_streaming(
            previous_text,
            current_text,
            delta,
            previous_tokens,
            current_tokens,
            token_delta,
        )
        if delta_message is not None:
            reconstructor.append_delta(delta_message)
        previous_text = current_text
        previous_tokens = current_tokens
    return reconstructor
@ -116,6 +116,35 @@ def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    validate_parsed_serve_args(args)


def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
    """Ensure validation fails if reasoning is enabled with auto tool choice"""
    args = serve_parser.parse_args(args=[
        "--enable-auto-tool-choice",
        "--enable-reasoning",
    ])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


def test_enable_reasoning_passes_with_reasoning_parser(serve_parser):
    """Ensure validation passes if reasoning is enabled
    with a reasoning parser"""
    args = serve_parser.parse_args(args=[
        "--enable-reasoning",
        "--reasoning-parser",
        "deepseek_r1",
    ])
    validate_parsed_serve_args(args)


def test_enable_reasoning_fails_without_reasoning_parser(serve_parser):
    """Ensure validation fails if reasoning is enabled
    without a reasoning parser"""
    args = serve_parser.parse_args(args=["--enable-reasoning"])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


def test_chat_template_validation_for_happy_paths(serve_parser):
    """Ensure validation passes if the chat template exists"""
    args = serve_parser.parse_args(
@ -61,6 +61,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@ -771,6 +772,8 @@ async def init_app_state(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        enable_reasoning=args.enable_reasoning,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    state.openai_serving_completion = OpenAIServingCompletion(
@ -844,6 +847,13 @@ async def run_server(args, **uvicorn_kwargs) -> None:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valid_tool_parses)} }})")

    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
    if args.enable_reasoning \
        and args.reasoning_parser not in valid_reasoning_parses:
        raise KeyError(
            f"invalid reasoning parser: {args.reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         validate_chat_template)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                    PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@ -208,6 +209,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        default=False,
        help="Enable auto tool choice for supported models. Use "
        "``--tool-call-parser`` to specify which parser to use.")
    parser.add_argument(
        "--enable-reasoning",
        action="store_true",
        default=False,
        help="Whether to enable reasoning_content for the model. "
        "If enabled, the model will be able to generate reasoning content.")

    valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
    parser.add_argument(
        "--reasoning-parser",
        type=str,
        metavar="{" + ",".join(valid_reasoning_parsers) + "}",
        default=None,
        help=
        "Select the reasoning parser depending on the model that you're using."
        " This is used to parse the reasoning content into OpenAI API "
        "format. Required for ``--enable-reasoning``.")

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    parser.add_argument(
@ -267,6 +285,18 @@ def validate_parsed_serve_args(args: argparse.Namespace):
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")

    # Enable reasoning needs a reasoning parser to be valid
    if args.enable_reasoning and not args.reasoning_parser:
        raise TypeError("Error: --enable-reasoning requires "
                        "--reasoning-parser")

    # Ref https://api-docs.deepseek.com/guides/reasoning_model
    # tool call and reasoning cannot be enabled at the same time.
    if args.enable_auto_tool_choice and args.enable_reasoning:
        raise TypeError(
            "Error: --enable-auto-tool-choice and "
            "--enable-reasoning cannot be enabled at the same time")


def create_parser_for_docs() -> FlexibleArgumentParser:
    parser_for_docs = FlexibleArgumentParser(
@ -1202,6 +1202,7 @@ class ExtractedToolCallInformation(BaseModel):

class ChatMessage(OpenAIBaseModel):
    role: str
    reasoning_content: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)

@ -1243,6 +1244,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
6 vllm/entrypoints/openai/reasoning_parsers/__init__.py Normal file
@ -0,0 +1,6 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser

__all__ = [
    "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
]
@ -0,0 +1,158 @@
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of

logger = init_logger(__name__)


class ReasoningParser:
    """
    Abstract reasoning parser class that should not be used directly.
    Provided methods should be used in derived classes.

    It is used to extract reasoning content from the model output.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        Tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.
        """

        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_calls "
            "has not been implemented!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming. Has to be an instance method because it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)
        """
        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_content_streaming "
            "has not been implemented!")


class ReasoningParserManager:
    reasoning_parsers: Dict[str, Type] = {}

    @classmethod
    def get_reasoning_parser(cls, name) -> Type:
        """
        Get reasoning parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        if name in cls.reasoning_parsers:
            return cls.reasoning_parsers[name]

        raise KeyError(f"reasoning helper: '{name}' not found in "
                       "reasoning_parsers")

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        if not issubclass(module, ReasoningParser):
            raise TypeError("module must be subclass of ReasoningParser, "
                            f"but got {type(module)}")
        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in cls.reasoning_parsers:
                existed_module = cls.reasoning_parsers[name]
                raise KeyError(f"{name} is already registered "
                               f"at {existed_module.__module__}")
            cls.reasoning_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as a
        decorator (with module as None) or as a normal function call (with
        module not None).
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                "name must be None, an instance of str, or a sequence of str, "
                f"but got {type(name)}")

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser from the path of the file
        that defines it.
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)
            return
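
For reference, a minimal usage sketch of the `register_module` API defined above; `MyParser` and the names used here are hypothetical and are not part of this change:

```python
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
    ReasoningParser, ReasoningParserManager)


# Used as a decorator: registers MyParser under the name "my_parser".
@ReasoningParserManager.register_module(["my_parser"])
class MyParser(ReasoningParser):
    ...


# Used as a plain call: registers the same class under an alias.
ReasoningParserManager.register_module(name="my_parser_alias", module=MyParser)

# Look-up by name, as the serving layer does for --reasoning-parser.
parser_cls = ReasoningParserManager.get_reasoning_parser("my_parser")
```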
@ -0,0 +1,133 @@
import re
from typing import Optional, Sequence, Tuple, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
    ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger

logger = init_logger(__name__)


@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> and </think> both in previous, so the reasoning
                # section has already ended; this delta is regular content.
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            logger.info(delta_text)
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta; treat the delta as regular
            # content.
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:

        # Check if the model output contains the <think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output
        else:
            # Use a regex to find the reasoning content
            reasoning_content = self.reasoning_regex.findall(model_output)[0]

            # Remove the reasoning content from the model output
            # Although deepseek's <think> token is always at the
            # beginning of the line, we cannot guarantee that
            # other models will follow this convention.
            # Therefore, we need to add :start_index.
            start_index = model_output.find(self.think_start_token)
            if start_index != -1:
                end_index = start_index + len(
                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
                )
                model_output = model_output[:start_index] + \
                                model_output[end_index:]

                if len(model_output) == 0:
                    return reasoning_content, None

            return reasoning_content, model_output
@ -21,6 +21,8 @@ from vllm.entrypoints.openai.protocol import (
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
    RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
@ -47,6 +49,8 @@ class OpenAIServingChat(OpenAIServing):
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        enable_reasoning: bool = False,
        reasoning_parser: Optional[str] = None,
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
@ -69,6 +73,18 @@ class OpenAIServingChat(OpenAIServing):
                " the parallel_tool_calls client option is preset for "
                "compatibility reasons, it will be ignored.")

        self.enable_reasoning: bool = enable_reasoning
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if self.enable_reasoning:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
            except Exception as e:
                raise TypeError("Error: --enable-reasoning requires "
                                f"reasoning_parser:'{reasoning_parser}' "
                                "which has not been registered") from e
        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            try:
@ -285,14 +301,35 @@ class OpenAIServingChat(OpenAIServing):
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or should_stream_with_reasoning_parsing:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            # There is no need to check if the reasoning_parser is None
            # because the should_stream_with_reasoning_parsing check
            # already ensures that the reasoning_parser is not None.
            # but the pre-commit hook requires it.
            if should_stream_with_reasoning_parsing and \
                self.reasoning_parser is not None:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
@ -456,6 +493,32 @@ class OpenAIServingChat(OpenAIServing):
                    # update the previous values for the next iteration
                    previous_texts[i] = current_text
                    all_previous_token_ids[i] = current_token_ids
                # reasoning_content cannot be enabled with tool_choice.
                # If it is, the tool_choice will be used instead.
                elif self.enable_reasoning:
                    # handle reasoning_content delta
                    assert reasoning_parser is not None
                    assert previous_texts is not None
                    assert all_previous_token_ids is not None
                    previous_text = previous_texts[i]
                    previous_token_ids = all_previous_token_ids[i]
                    current_text = previous_text + delta_text
                    current_token_ids = previous_token_ids + list(
                        output.token_ids)

                    delta_message = (reasoning_parser.
                                     extract_reasoning_content_streaming(
                                         previous_text,
                                         current_text,
                                         delta_text,
                                         previous_token_ids,
                                         current_token_ids,
                                         output.token_ids,
                                     ))

                    # update the previous values for the next iteration
                    previous_texts[i] = current_text
                    all_previous_token_ids[i] = current_token_ids

                # handle streaming just a content delta
                else:
@ -642,17 +705,38 @@ class OpenAIServingChat(OpenAIServing):
            else:
                logprobs = None

            should_stream_with_reasoning_parsing = (
                self._should_stream_with_reasoning_parsing(request))

            # In the OpenAI API the finish_reason is "tools_called"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            if should_stream_with_reasoning_parsing and \
                self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
@ -835,6 +919,17 @@ class OpenAIServingChat(OpenAIServing):
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_stream_with_reasoning_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the
        reasoning parser that was configured.

        We only want to do this IF reasoning is enabled and a reasoning
        parser is configured.
        """
        return self.enable_reasoning and self.reasoning_parser is not None

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
@ -167,6 +167,7 @@ def main():
        "Must be a YAML with the following options:"
        "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
    )

    serve_parser = make_arg_parser(serve_parser)
    serve_parser.set_defaults(dispatch_function=serve)