mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:17:16 +08:00
[Model] Add reasoning_parser and tool_parser for Ernie45 thinking (#25027)
Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
parent
98f30b8cba
commit
782505ed8e
@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models:
|
|||||||
| Model Series | Parser Name | Structured Output Support | Tool Calling |
|
| Model Series | Parser Name | Structured Output Support | Tool Calling |
|
||||||
|--------------|-------------|------------------|-------------|
|
|--------------|-------------|------------------|-------------|
|
||||||
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
|
||||||
|
| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
|
||||||
|
| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
|
||||||
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
|
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
|
||||||
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
|
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
|
||||||
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ |
|
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ |
|
||||||
|
|||||||
124
tests/reasoning/test_ernie45_reasoning_parser.py
Normal file
124
tests/reasoning/test_ernie45_reasoning_parser.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
parser_name = "ernie45"
|
||||||
|
|
||||||
|
REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ernie45_tokenizer():
|
||||||
|
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
# 带 </think>,非stream
|
||||||
|
WITH_THINK = {
|
||||||
|
"output": "abc</think>def",
|
||||||
|
"reasoning_content": "abc",
|
||||||
|
"content": "def",
|
||||||
|
}
|
||||||
|
# 带 </think>,stream
|
||||||
|
WITH_THINK_STREAM = {
|
||||||
|
"output": "abc</think>def",
|
||||||
|
"reasoning_content": "abc",
|
||||||
|
"content": "def",
|
||||||
|
}
|
||||||
|
# without </think>, all is reasoning_content
|
||||||
|
WITHOUT_THINK = {
|
||||||
|
"output": "abc",
|
||||||
|
"reasoning_content": "abc",
|
||||||
|
"content": None,
|
||||||
|
}
|
||||||
|
# without </think>, all is reasoning_content
|
||||||
|
WITHOUT_THINK_STREAM = {
|
||||||
|
"output": "abc",
|
||||||
|
"reasoning_content": "abc",
|
||||||
|
"content": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPLETE_REASONING = {
|
||||||
|
"output": "abc</think>",
|
||||||
|
"reasoning_content": "abc",
|
||||||
|
"content": None,
|
||||||
|
}
|
||||||
|
MULTILINE_REASONING = {
|
||||||
|
"output": "abc\nABC</think>def\nDEF",
|
||||||
|
"reasoning_content": "abc\nABC",
|
||||||
|
"content": "def\nDEF",
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
WITH_THINK,
|
||||||
|
id="with_think",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
WITH_THINK_STREAM,
|
||||||
|
id="with_think_stream",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
WITHOUT_THINK,
|
||||||
|
id="without_think",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
WITHOUT_THINK_STREAM,
|
||||||
|
id="without_think_stream",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
COMPLETE_REASONING,
|
||||||
|
id="complete_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
COMPLETE_REASONING,
|
||||||
|
id="complete_reasoning_stream",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
MULTILINE_REASONING,
|
||||||
|
id="multiline_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
MULTILINE_REASONING,
|
||||||
|
id="multiline_reasoning_stream",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||||
|
def test_reasoning(
|
||||||
|
streaming: bool,
|
||||||
|
param_dict: dict,
|
||||||
|
ernie45_tokenizer,
|
||||||
|
):
|
||||||
|
output = ernie45_tokenizer.tokenize(param_dict["output"])
|
||||||
|
output_tokens: list[str] = []
|
||||||
|
for token in output:
|
||||||
|
one_token = ernie45_tokenizer.convert_tokens_to_string([token])
|
||||||
|
if one_token:
|
||||||
|
output_tokens.append(one_token)
|
||||||
|
|
||||||
|
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||||
|
ernie45_tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, output_tokens, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
assert reasoning == param_dict["reasoning_content"]
|
||||||
|
assert content == param_dict["content"]
|
||||||
359
tests/tool_use/test_ernie45_moe_tool_parser.py
Normal file
359
tests/tool_use/test_ernie45_moe_tool_parser.py
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
DeltaMessage,
|
||||||
|
FunctionCall,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser
|
||||||
|
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||||
|
|
||||||
|
# Use a common model that is likely to be available
|
||||||
|
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ernie45_tokenizer():
|
||||||
|
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ernie45_tool_parser(ernie45_tokenizer):
|
||||||
|
return Ernie45ToolParser(ernie45_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tool_calls(
|
||||||
|
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||||
|
):
|
||||||
|
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||||
|
|
||||||
|
for actual_tool_call, expected_tool_call in zip(
|
||||||
|
actual_tool_calls, expected_tool_calls
|
||||||
|
):
|
||||||
|
assert isinstance(actual_tool_call.id, str)
|
||||||
|
assert len(actual_tool_call.id) > 0
|
||||||
|
|
||||||
|
assert actual_tool_call.type == "function"
|
||||||
|
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||||
|
# Compare arguments as JSON objects to handle formatting differences
|
||||||
|
actual_args = json.loads(actual_tool_call.function.arguments)
|
||||||
|
expected_args = json.loads(expected_tool_call.function.arguments)
|
||||||
|
assert actual_args == expected_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_no_tools(ernie45_tool_parser):
|
||||||
|
model_output = "This is a test"
|
||||||
|
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
|
||||||
|
model_output, request=None
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
assert not extracted_tool_calls.tools_called
|
||||||
|
assert extracted_tool_calls.tool_calls == []
|
||||||
|
assert extracted_tool_calls.content == model_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
ids=[
|
||||||
|
"single_tool_call",
|
||||||
|
"multiple_tool_calls",
|
||||||
|
"tool_call_with_content_before",
|
||||||
|
],
|
||||||
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
|
argvalues=[
|
||||||
|
(
|
||||||
|
"""<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_temperature_unit",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Guangzhou",
|
||||||
|
"unit": "c",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""I need to call two tools to handle these two issues separately.
|
||||||
|
</think>
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_temperature_unit",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Guangzhou",
|
||||||
|
"unit": "c",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"I need to call two tools to handle these two issues separately.\n</think>",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_tool_calls(
|
||||||
|
ernie45_tool_parser, model_output, expected_tool_calls, expected_content
|
||||||
|
):
|
||||||
|
extracted_tool_calls = ernie45_tool_parser.extract_tool_calls(
|
||||||
|
model_output, request=None
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
assert extracted_tool_calls.tools_called
|
||||||
|
|
||||||
|
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||||
|
|
||||||
|
assert extracted_tool_calls.content == expected_content
|
||||||
|
|
||||||
|
|
||||||
|
def stream_delta_message_generator(
|
||||||
|
ernie45_tool_parser: Ernie45ToolParser,
|
||||||
|
ernie45_tokenizer: AnyTokenizer,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest | None = None,
|
||||||
|
) -> Generator[DeltaMessage, None, None]:
|
||||||
|
all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False)
|
||||||
|
|
||||||
|
previous_text = ""
|
||||||
|
previous_tokens = None
|
||||||
|
prefix_offset = 0
|
||||||
|
read_offset = 0
|
||||||
|
for i, delta_token in enumerate(all_token_ids):
|
||||||
|
delta_token_ids = [delta_token]
|
||||||
|
previous_token_ids = all_token_ids[:i]
|
||||||
|
current_token_ids = all_token_ids[: i + 1]
|
||||||
|
|
||||||
|
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||||
|
detokenize_incrementally(
|
||||||
|
tokenizer=ernie45_tokenizer,
|
||||||
|
all_input_ids=current_token_ids,
|
||||||
|
prev_tokens=previous_tokens,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
spaces_between_special_tokens=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
|
||||||
|
delta_message = ernie45_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text,
|
||||||
|
current_text,
|
||||||
|
delta_text,
|
||||||
|
previous_token_ids,
|
||||||
|
current_token_ids,
|
||||||
|
delta_token_ids,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
if delta_message:
|
||||||
|
yield delta_message
|
||||||
|
|
||||||
|
previous_text = current_text
|
||||||
|
previous_tokens = (
|
||||||
|
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||||
|
)
|
||||||
|
prefix_offset = new_prefix_offset
|
||||||
|
read_offset = new_read_offset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
ids=[
|
||||||
|
"single_tool_call",
|
||||||
|
"multiple_tool_calls",
|
||||||
|
"tool_call_with_content_before",
|
||||||
|
],
|
||||||
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
|
argvalues=[
|
||||||
|
(
|
||||||
|
"""<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_temperature_unit",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Guangzhou",
|
||||||
|
"unit": "c",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""I need to call two tools to handle these two issues separately.
|
||||||
|
</think>
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
|
||||||
|
</tool_call>
|
||||||
|
""",
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_current_temperature",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Beijing",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_temperature_unit",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Guangzhou",
|
||||||
|
"unit": "c",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"I need to call two tools to handle these two issues separately.\n</think>",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_tool_calls_streaming_incremental(
|
||||||
|
ernie45_tool_parser,
|
||||||
|
ernie45_tokenizer,
|
||||||
|
model_output,
|
||||||
|
expected_tool_calls,
|
||||||
|
expected_content,
|
||||||
|
):
|
||||||
|
"""Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
|
||||||
|
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
|
||||||
|
|
||||||
|
tool_calls_dict = {}
|
||||||
|
for delta_message in stream_delta_message_generator(
|
||||||
|
ernie45_tool_parser, ernie45_tokenizer, model_output, request
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
delta_message.role is None
|
||||||
|
and delta_message.content is None
|
||||||
|
and delta_message.reasoning_content is None
|
||||||
|
and len(delta_message.tool_calls) == 0
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
tool_calls = delta_message.tool_calls
|
||||||
|
for tool_call_chunk in tool_calls:
|
||||||
|
index = tool_call_chunk.index
|
||||||
|
if index not in tool_calls_dict:
|
||||||
|
if tool_call_chunk.function.arguments is None:
|
||||||
|
tool_call_chunk.function.arguments = ""
|
||||||
|
tool_calls_dict[index] = tool_call_chunk
|
||||||
|
else:
|
||||||
|
tool_calls_dict[
|
||||||
|
index
|
||||||
|
].function.arguments += tool_call_chunk.function.arguments
|
||||||
|
actual_tool_calls = list(tool_calls_dict.values())
|
||||||
|
|
||||||
|
assert len(actual_tool_calls) > 0
|
||||||
|
# check tool call format
|
||||||
|
assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||||
@ -4,6 +4,7 @@
|
|||||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||||
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
||||||
|
from .ernie45_tool_parser import Ernie45ToolParser
|
||||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||||
from .granite_tool_parser import GraniteToolParser
|
from .granite_tool_parser import GraniteToolParser
|
||||||
@ -42,6 +43,7 @@ __all__ = [
|
|||||||
"Phi4MiniJsonToolParser",
|
"Phi4MiniJsonToolParser",
|
||||||
"DeepSeekV3ToolParser",
|
"DeepSeekV3ToolParser",
|
||||||
"DeepSeekV31ToolParser",
|
"DeepSeekV31ToolParser",
|
||||||
|
"Ernie45ToolParser",
|
||||||
"xLAMToolParser",
|
"xLAMToolParser",
|
||||||
"MinimaxToolParser",
|
"MinimaxToolParser",
|
||||||
"KimiK2ToolParser",
|
"KimiK2ToolParser",
|
||||||
|
|||||||
212
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
Normal file
212
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
DeltaFunctionCall,
|
||||||
|
DeltaMessage,
|
||||||
|
DeltaToolCall,
|
||||||
|
ExtractedToolCallInformation,
|
||||||
|
FunctionCall,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
|
ToolParser,
|
||||||
|
ToolParserManager,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserManager.register_module("ernie45")
|
||||||
|
class Ernie45ToolParser(ToolParser):
|
||||||
|
def __init__(self, tokenizer: AnyTokenizer):
|
||||||
|
"""
|
||||||
|
Ernie thinking model format:
|
||||||
|
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
|
||||||
|
"""
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
self.current_tool_name_sent = False
|
||||||
|
self.prev_tool_call_arr: list[dict] = []
|
||||||
|
self.current_tool_id = -1
|
||||||
|
self.streamed_args_for_tool: list[str] = []
|
||||||
|
self.think_end_token = "</think>"
|
||||||
|
self.response_start_token: str = "<response>"
|
||||||
|
self.response_end_token: str = "</response>"
|
||||||
|
self.tool_call_start_token = "<tool_call>"
|
||||||
|
self.tool_call_end_token = "</tool_call>"
|
||||||
|
self.tool_calls_start_token = self.tool_call_start_token
|
||||||
|
self.newline_token: str = "<0x0A>"
|
||||||
|
|
||||||
|
self.tool_call_regex = re.compile(
|
||||||
|
r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.model_tokenizer:
|
||||||
|
raise ValueError(
|
||||||
|
"The model tokenizer must be passed to the ToolParser "
|
||||||
|
"constructor during construction."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||||
|
self.response_start_token_id = self.vocab.get(self.response_start_token)
|
||||||
|
self.response_end_token_id = self.vocab.get(self.response_end_token)
|
||||||
|
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||||
|
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||||
|
self.newline_token_id = self.vocab.get(self.newline_token)
|
||||||
|
self.parser_token_ids = [
|
||||||
|
self.think_end_token_id,
|
||||||
|
self.response_start_token_id,
|
||||||
|
self.response_end_token_id,
|
||||||
|
]
|
||||||
|
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
def extract_tool_calls(
|
||||||
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> ExtractedToolCallInformation:
|
||||||
|
# sanity check; avoid unnecessary processing
|
||||||
|
if self.tool_calls_start_token not in model_output:
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
tool_call_json_list = self.tool_call_regex.findall(model_output)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
for tool_call_json in tool_call_json_list:
|
||||||
|
tool_call_dict = json.loads(tool_call_json)
|
||||||
|
args_str = json.dumps(
|
||||||
|
tool_call_dict.get("arguments", {}), ensure_ascii=False
|
||||||
|
)
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=tool_call_dict.get("name", ""),
|
||||||
|
arguments=args_str,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
content = model_output[
|
||||||
|
: model_output.find(self.tool_calls_start_token)
|
||||||
|
].rstrip("\n")
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=True,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=content if content else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error in extracting tool call from response.")
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_tool_calls_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> DeltaMessage | None:
|
||||||
|
self._buffer += delta_text
|
||||||
|
cur_text = self._buffer
|
||||||
|
start_idx = cur_text.find(self.tool_call_start_token)
|
||||||
|
if start_idx == -1:
|
||||||
|
self._buffer = ""
|
||||||
|
# At least one toolcall has been completed
|
||||||
|
if self.current_tool_id > 0:
|
||||||
|
cur_text = ""
|
||||||
|
if self.current_tool_id == -1 and all(
|
||||||
|
token_id == self.newline_token_id for token_id in previous_token_ids
|
||||||
|
):
|
||||||
|
cur_text = cur_text.strip("\n")
|
||||||
|
|
||||||
|
# handle <response> </response> when tool_call is not triggered
|
||||||
|
# cur_text === delta_text
|
||||||
|
content = cur_text
|
||||||
|
if self.response_start_token_id in delta_token_ids:
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
response_start_idx = content.find(self.response_start_token)
|
||||||
|
content = content[response_start_idx + len(self.response_start_token) :]
|
||||||
|
# if have </response>, remove it
|
||||||
|
response_end_idx = content.rfind(self.response_end_token)
|
||||||
|
if response_end_idx != -1:
|
||||||
|
content = content[:response_end_idx]
|
||||||
|
elif self.response_end_token_id in delta_token_ids:
|
||||||
|
response_end_idx = content.rfind(self.response_end_token)
|
||||||
|
content = content[:response_end_idx]
|
||||||
|
# remove \n after </think> or <response> or </response>
|
||||||
|
if (
|
||||||
|
len(previous_token_ids) > 0
|
||||||
|
and previous_token_ids[-1] in self.parser_token_ids
|
||||||
|
) and (
|
||||||
|
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||||
|
):
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
|
||||||
|
return DeltaMessage(content=content if content else None)
|
||||||
|
logger.debug("cur_text = %s", cur_text)
|
||||||
|
end_idx = cur_text.find(self.tool_call_end_token)
|
||||||
|
if end_idx != -1:
|
||||||
|
if self.current_tool_id == -1:
|
||||||
|
self.current_tool_id = 0
|
||||||
|
self.prev_tool_call_arr = []
|
||||||
|
self.streamed_args_for_tool = []
|
||||||
|
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||||
|
self.prev_tool_call_arr.append({})
|
||||||
|
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
|
extracted_tool_calls = self.extract_tool_calls(
|
||||||
|
cur_text[: end_idx + len(self.tool_call_end_token)], request
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(extracted_tool_calls.tool_calls) == 0:
|
||||||
|
logger.warning("Failed to extract any tool calls.")
|
||||||
|
return None
|
||||||
|
tool_call = extracted_tool_calls.tool_calls[0]
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": json.loads(tool_call.function.arguments),
|
||||||
|
}
|
||||||
|
self.streamed_args_for_tool[self.current_tool_id] = (
|
||||||
|
tool_call.function.arguments
|
||||||
|
)
|
||||||
|
delta = DeltaMessage(
|
||||||
|
content=extracted_tool_calls.content,
|
||||||
|
tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_id,
|
||||||
|
id=tool_call.id,
|
||||||
|
type=tool_call.type,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=tool_call.function.name,
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.current_tool_id += 1
|
||||||
|
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
|
||||||
|
return delta
|
||||||
|
|
||||||
|
self._buffer = cur_text[start_idx:]
|
||||||
|
content = cur_text[:start_idx].rstrip("\n")
|
||||||
|
return DeltaMessage(content=content if content else None)
|
||||||
@ -4,6 +4,7 @@
|
|||||||
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||||
from .basic_parsers import BaseThinkingReasoningParser
|
from .basic_parsers import BaseThinkingReasoningParser
|
||||||
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||||
|
from .ernie45_reasoning_parser import Ernie45ReasoningParser
|
||||||
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
||||||
from .gptoss_reasoning_parser import GptOssReasoningParser
|
from .gptoss_reasoning_parser import GptOssReasoningParser
|
||||||
from .granite_reasoning_parser import GraniteReasoningParser
|
from .granite_reasoning_parser import GraniteReasoningParser
|
||||||
@ -19,6 +20,7 @@ __all__ = [
|
|||||||
"BaseThinkingReasoningParser",
|
"BaseThinkingReasoningParser",
|
||||||
"ReasoningParserManager",
|
"ReasoningParserManager",
|
||||||
"DeepSeekR1ReasoningParser",
|
"DeepSeekR1ReasoningParser",
|
||||||
|
"Ernie45ReasoningParser",
|
||||||
"GraniteReasoningParser",
|
"GraniteReasoningParser",
|
||||||
"HunyuanA13BReasoningParser",
|
"HunyuanA13BReasoningParser",
|
||||||
"Qwen3ReasoningParser",
|
"Qwen3ReasoningParser",
|
||||||
|
|||||||
169
vllm/reasoning/ernie45_reasoning_parser.py
Normal file
169
vllm/reasoning/ernie45_reasoning_parser.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.reasoning import ReasoningParserManager
|
||||||
|
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ReasoningParserManager.register_module("ernie45")
|
||||||
|
class Ernie45ReasoningParser(BaseThinkingReasoningParser):
|
||||||
|
"""
|
||||||
|
Reasoning parser for Ernie45 thinking model.
|
||||||
|
The Ernie45 thinking model ouput format is
|
||||||
|
abc\n</think>\n\n<response>\ndef\n</response>\n
|
||||||
|
or abc\n</think>\ndef
|
||||||
|
"""
|
||||||
|
|
||||||
|
response_start_token: str = "<response>"
|
||||||
|
response_end_token: str = "</response>"
|
||||||
|
newline_token: str = "<0x0A>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
"""The token that starts reasoning content."""
|
||||||
|
return "<think>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
"""The token that ends reasoning content."""
|
||||||
|
return "</think>"
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
|
||||||
|
if not self.model_tokenizer:
|
||||||
|
raise ValueError(
|
||||||
|
"The model tokenizer must be passed to the ReasoningParser "
|
||||||
|
"constructor during construction."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_token_id = self.vocab.get(self.start_token)
|
||||||
|
self.end_token_id = self.vocab.get(self.end_token)
|
||||||
|
self.response_start_token_id = self.vocab.get(self.response_start_token)
|
||||||
|
self.response_end_token_id = self.vocab.get(self.response_end_token)
|
||||||
|
self.newline_token_id = self.vocab.get(self.newline_token)
|
||||||
|
|
||||||
|
self.parser_token_ids = [self.end_token_id, self.response_end_token_id]
|
||||||
|
|
||||||
|
if self.start_token_id is None or self.end_token_id is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Ernie45 reasoning parser could not locate think start/end "
|
||||||
|
"tokens in the tokenizer!"
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_reasoning_content_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
) -> DeltaMessage | None:
|
||||||
|
"""
|
||||||
|
Extract reasoning content from a delta message.
|
||||||
|
Handles streaming output where previous + delta = current.
|
||||||
|
Uses token IDs for faster processing.
|
||||||
|
The Ernie45 thinking model ouput format is
|
||||||
|
abc\n</think>\n\n<response>\ndef\n</response>\n
|
||||||
|
or abc\n</think>\ndef
|
||||||
|
- 'abc' goes to reasoning_content
|
||||||
|
- 'def' goes to content
|
||||||
|
"""
|
||||||
|
# Skip single special tokens
|
||||||
|
if len(delta_token_ids) == 1 and (
|
||||||
|
delta_token_ids[0]
|
||||||
|
in [
|
||||||
|
self.start_token_id,
|
||||||
|
self.end_token_id,
|
||||||
|
self.response_start_token_id,
|
||||||
|
self.response_end_token_id,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# No <think> in previous or delta, also need to check for </think>.
|
||||||
|
# Because the model may have generated </think> without <think>
|
||||||
|
if self.end_token_id in delta_token_ids:
|
||||||
|
# </think> in delta with more tokens,
|
||||||
|
# extract reasoning content and content
|
||||||
|
think_end_index = delta_text.find(self.end_token)
|
||||||
|
reasoning_content = delta_text[:think_end_index]
|
||||||
|
content = delta_text[think_end_index + len(self.end_token) :]
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
response_start_idx = content.find(self.response_start_token)
|
||||||
|
response_end_idx = content.rfind(self.response_end_token)
|
||||||
|
if response_start_idx != -1:
|
||||||
|
content = content[response_start_idx + len(self.response_start_token) :]
|
||||||
|
if response_end_idx != -1:
|
||||||
|
content = content[:response_end_idx]
|
||||||
|
return DeltaMessage(
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
content=content if content else None,
|
||||||
|
)
|
||||||
|
elif self.end_token_id in previous_token_ids:
|
||||||
|
# </think> in previous, thinking content ends
|
||||||
|
content = delta_text
|
||||||
|
if self.response_start_token_id in delta_token_ids:
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
response_start_idx = content.find(self.response_start_token)
|
||||||
|
content = content[response_start_idx + len(self.response_start_token) :]
|
||||||
|
# if have </response>, remove it
|
||||||
|
response_end_idx = content.rfind(self.response_end_token)
|
||||||
|
if response_end_idx != -1:
|
||||||
|
content = content[:response_end_idx]
|
||||||
|
elif self.response_end_token_id in delta_token_ids:
|
||||||
|
response_end_idx = content.rfind(self.response_end_token)
|
||||||
|
content = content[:response_end_idx]
|
||||||
|
# remove \n after </think> or </response>
|
||||||
|
if previous_token_ids[-1] in self.parser_token_ids and (
|
||||||
|
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||||
|
):
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
# remove \n after </think>\n
|
||||||
|
if (
|
||||||
|
len(previous_token_ids) > 1
|
||||||
|
and previous_token_ids[-2] == self.end_token_id
|
||||||
|
) and (
|
||||||
|
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||||
|
):
|
||||||
|
content = content.lstrip("\n")
|
||||||
|
|
||||||
|
return DeltaMessage(content=content if content else None)
|
||||||
|
else:
|
||||||
|
# no </think> in previous or delta, reasoning content continues
|
||||||
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
|
||||||
|
def extract_reasoning_content(
|
||||||
|
self, model_output: str, request: ChatCompletionRequest
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
"""
|
||||||
|
Extract reasoning content from the model output.
|
||||||
|
The Ernie45 thinking model ouput format is
|
||||||
|
abc\n</think>\n\n\n<response>\ndef\n</response>\n
|
||||||
|
or abc\n</think>\ndef
|
||||||
|
- 'abc' goes to reasoning_content
|
||||||
|
- 'def' goes to content
|
||||||
|
Returns:
|
||||||
|
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||||
|
"""
|
||||||
|
reasoning_content, content = super().extract_reasoning_content(
|
||||||
|
model_output, request
|
||||||
|
)
|
||||||
|
if content:
|
||||||
|
start_idx = content.find(self.response_start_token)
|
||||||
|
end_idx = content.rfind(self.response_end_token)
|
||||||
|
# Simultaneously existing and in the correct order
|
||||||
|
if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
|
||||||
|
content = content[start_idx + len(self.response_start_token) : end_idx]
|
||||||
|
final_content = content or None
|
||||||
|
|
||||||
|
return reasoning_content, final_content
|
||||||
Loading…
x
Reference in New Issue
Block a user