From 782505ed8eb4f1b27cccd009a8dc9b69f6ad6ebc Mon Sep 17 00:00:00 2001 From: CSWYF3634076 Date: Mon, 13 Oct 2025 15:55:20 +0800 Subject: [PATCH] [Model] Add reasoning_parser and tool_parser for Ernie45 thinking (#25027) Signed-off-by: wangyafeng --- docs/features/reasoning_outputs.md | 2 + .../test_ernie45_reasoning_parser.py | 124 ++++++ .../tool_use/test_ernie45_moe_tool_parser.py | 359 ++++++++++++++++++ .../openai/tool_parsers/__init__.py | 2 + .../tool_parsers/ernie45_tool_parser.py | 212 +++++++++++ vllm/reasoning/__init__.py | 2 + vllm/reasoning/ernie45_reasoning_parser.py | 169 +++++++++ 7 files changed, 870 insertions(+) create mode 100644 tests/reasoning/test_ernie45_reasoning_parser.py create mode 100644 tests/tool_use/test_ernie45_moe_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py create mode 100644 vllm/reasoning/ernie45_reasoning_parser.py diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 85681669dfb22..389b3cb21ef5d 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | Tool Calling | |--------------|-------------|------------------|-------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ | +| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ | +| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 
series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ | diff --git a/tests/reasoning/test_ernie45_reasoning_parser.py b/tests/reasoning/test_ernie45_reasoning_parser.py new file mode 100644 index 0000000000000..344478013e6b4 --- /dev/null +++ b/tests/reasoning/test_ernie45_reasoning_parser.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "ernie45" + +REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking" + + +@pytest.fixture(scope="module") +def ernie45_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +# 带 ,非stream +WITH_THINK = { + "output": "abcdef", + "reasoning_content": "abc", + "content": "def", +} +# 带 ,stream +WITH_THINK_STREAM = { + "output": "abcdef", + "reasoning_content": "abc", + "content": "def", +} +# without , all is reasoning_content +WITHOUT_THINK = { + "output": "abc", + "reasoning_content": "abc", + "content": None, +} +# without , all is reasoning_content +WITHOUT_THINK_STREAM = { + "output": "abc", + "reasoning_content": "abc", + "content": None, +} + +COMPLETE_REASONING = { + "output": "abc", + "reasoning_content": "abc", + "content": None, +} +MULTILINE_REASONING = { + "output": "abc\nABCdef\nDEF", + "reasoning_content": "abc\nABC", + "content": "def\nDEF", +} + +TEST_CASES = [ + pytest.param( + False, + WITH_THINK, + id="with_think", + ), + pytest.param( + True, + WITH_THINK_STREAM, + id="with_think_stream", + ), + pytest.param( + False, + WITHOUT_THINK, + id="without_think", + ), + pytest.param( + True, + WITHOUT_THINK_STREAM, + id="without_think_stream", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + 
COMPLETE_REASONING, + id="complete_reasoning_stream", + ), + pytest.param( + False, + MULTILINE_REASONING, + id="multiline_reasoning", + ), + pytest.param( + True, + MULTILINE_REASONING, + id="multiline_reasoning_stream", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + ernie45_tokenizer, +): + output = ernie45_tokenizer.tokenize(param_dict["output"]) + output_tokens: list[str] = [] + for token in output: + one_token = ernie45_tokenizer.convert_tokens_to_string([token]) + if one_token: + output_tokens.append(one_token) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + ernie45_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + print() + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/tool_use/test_ernie45_moe_tool_parser.py b/tests/tool_use/test_ernie45_moe_tool_parser.py new file mode 100644 index 0000000000000..0862d14812d72 --- /dev/null +++ b/tests/tool_use/test_ernie45_moe_tool_parser.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator + +import pytest + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking" + + +@pytest.fixture(scope="module") +def ernie45_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def 
ernie45_tool_parser(ernie45_tokenizer): + return Ernie45ToolParser(ernie45_tokenizer) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 0 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function.name == expected_tool_call.function.name + # Compare arguments as JSON objects to handle formatting differences + actual_args = json.loads(actual_tool_call.function.arguments) + expected_args = json.loads(expected_tool_call.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_no_tools(ernie45_tool_parser): + model_output = "This is a test" + extracted_tool_calls = ernie45_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """ +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ) + ], + None, + ), + ( + """ +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + + +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + 
function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + None, + ), + ( + """I need to call two tools to handle these two issues separately. + + + +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + + +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + "I need to call two tools to handle these two issues separately.\n", + ), + ], +) +def test_extract_tool_calls( + ernie45_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = ernie45_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def stream_delta_message_generator( + ernie45_tool_parser: Ernie45ToolParser, + ernie45_tokenizer: AnyTokenizer, + model_output: str, + request: ChatCompletionRequest | None = None, +) -> Generator[DeltaMessage, None, None]: + all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=ernie45_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + 
prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) + + current_text = previous_text + delta_text + + delta_message = ernie45_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """ +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ) + ], + None, + ), + ( + """ +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + + +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + None, + ), + ( + """I need to call two tools to handle these two issues separately. 
+ + + +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} + + +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} + +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + "I need to call two tools to handle these two issues separately.\n", + ), + ], +) +def test_extract_tool_calls_streaming_incremental( + ernie45_tool_parser, + ernie45_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + """Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501 + request = ChatCompletionRequest(model=MODEL, messages=[], tools=[]) + + tool_calls_dict = {} + for delta_message in stream_delta_message_generator( + ernie45_tool_parser, ernie45_tokenizer, model_output, request + ): + if ( + delta_message.role is None + and delta_message.content is None + and delta_message.reasoning_content is None + and len(delta_message.tool_calls) == 0 + ): + continue + tool_calls = delta_message.tool_calls + for tool_call_chunk in tool_calls: + index = tool_call_chunk.index + if index not in tool_calls_dict: + if tool_call_chunk.function.arguments is None: + tool_call_chunk.function.arguments = "" + tool_calls_dict[index] = tool_call_chunk + else: + tool_calls_dict[ + index + ].function.arguments += tool_call_chunk.function.arguments + actual_tool_calls = list(tool_calls_dict.values()) + + assert len(actual_tool_calls) > 0 + # check tool call format + assert_tool_calls(actual_tool_calls, expected_tool_calls) diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 2c5a0a6af23f0..859da8392fc07 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ 
b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -4,6 +4,7 @@
 from .abstract_tool_parser import ToolParser, ToolParserManager
 from .deepseekv3_tool_parser import DeepSeekV3ToolParser
 from .deepseekv31_tool_parser import DeepSeekV31ToolParser
+from .ernie45_tool_parser import Ernie45ToolParser
 from .glm4_moe_tool_parser import Glm4MoeModelToolParser
 from .granite_20b_fc_tool_parser import Granite20bFCToolParser
 from .granite_tool_parser import GraniteToolParser
@@ -42,6 +43,7 @@ __all__ = [
     "Phi4MiniJsonToolParser",
     "DeepSeekV3ToolParser",
     "DeepSeekV31ToolParser",
+    "Ernie45ToolParser",
     "xLAMToolParser",
     "MinimaxToolParser",
     "KimiK2ToolParser",
diff --git a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
new file mode 100644
index 0000000000000..e4696334eb135
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Sequence
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser,
+    ToolParserManager,
+)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("ernie45")
+class Ernie45ToolParser(ToolParser):
+    """Tool-call parser for ERNIE-4.5 thinking models.
+
+    Expected model output shape (the chat template forces the opening
+    ``<think>``, so it never appears in the generation):
+
+        reasoning</think>\n\n\n<response>\n<tool_call>\n{json}\n</tool_call>\n...\n</response>\n
+
+    NOTE(review): exact newline layout is taken from the ERNIE-4.5 chat
+    template — confirm against the model card if it changes.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        """Initialize delimiter tokens, their vocab ids, and streaming state."""
+        super().__init__(tokenizer)
+        self.current_tool_name_sent = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id = -1
+        self.streamed_args_for_tool: list[str] = []
+        # Delimiter strings emitted by the model between reasoning,
+        # response text and tool calls.
+        self.think_end_token = "</think>"
+        self.response_start_token: str = "<response>"
+        self.response_end_token: str = "</response>"
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+        self.tool_calls_start_token = self.tool_call_start_token
+        # The tokenizer encodes "\n" as the byte-level token <0x0A>.
+        self.newline_token: str = "<0x0A>"
+
+        # Captures the JSON payload between a <tool_call>...</tool_call> pair.
+        self.tool_call_regex = re.compile(
+            r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
+        )
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction."
+            )
+
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+        self.response_start_token_id = self.vocab.get(self.response_start_token)
+        self.response_end_token_id = self.vocab.get(self.response_end_token)
+        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+        self.newline_token_id = self.vocab.get(self.newline_token)
+        # Tokens after which a leading "\n" in the content is cosmetic
+        # and should be stripped during streaming.
+        self.parser_token_ids = [
+            self.think_end_token_id,
+            self.response_start_token_id,
+            self.response_end_token_id,
+        ]
+
+        # Accumulates streamed text until a complete tool call is seen.
+        self._buffer = ""
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract all complete tool calls from a full (non-streaming) output.
+
+        Returns the parsed tool calls plus any content that precedes the
+        first <tool_call>; on any parse error, falls back to returning the
+        raw output as plain content.
+        """
+        # sanity check; avoid unnecessary processing
+        if self.tool_calls_start_token not in model_output:
+            return ExtractedToolCallInformation(
+                tools_called=False, tool_calls=[], content=model_output
+            )
+
+        else:
+            try:
+                tool_call_json_list = self.tool_call_regex.findall(model_output)
+
+                tool_calls = []
+                for tool_call_json in tool_call_json_list:
+                    tool_call_dict = json.loads(tool_call_json)
+                    args_str = json.dumps(
+                        tool_call_dict.get("arguments", {}), ensure_ascii=False
+                    )
+                    tool_calls.append(
+                        ToolCall(
+                            type="function",
+                            function=FunctionCall(
+                                name=tool_call_dict.get("name", ""),
+                                arguments=args_str,
+                            ),
+                        )
+                    )
+
+                # Everything before the first <tool_call> is user-visible
+                # content; trailing newlines are template padding.
+                content = model_output[
+                    : model_output.find(self.tool_calls_start_token)
+                ].rstrip("\n")
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None,
+                )
+
+            except Exception:
+                logger.exception("Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(
+                    tools_called=False, tool_calls=[], content=model_output
+                )
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> DeltaMessage | None:
+        """Incrementally parse streamed text, emitting one DeltaToolCall per
+        completed <tool_call>...</tool_call> block and plain content otherwise.
+        """
+        self._buffer += delta_text
+        cur_text = self._buffer
+        start_idx = cur_text.find(self.tool_call_start_token)
+        if start_idx == -1:
+            # No tool call in flight: flush the buffer as content.
+            self._buffer = ""
+            # At least one toolcall has been completed
+            if self.current_tool_id > 0:
+                cur_text = ""
+            if self.current_tool_id == -1 and all(
+                token_id == self.newline_token_id for token_id in previous_token_ids
+            ):
+                cur_text = cur_text.strip("\n")
+
+            # handle when tool_call is not triggered
+            # cur_text === delta_text
+            content = cur_text
+            if self.response_start_token_id in delta_token_ids:
+                content = content.lstrip("\n")
+                response_start_idx = content.find(self.response_start_token)
+                content = content[response_start_idx + len(self.response_start_token) :]
+                # if have </response>, remove it
+                response_end_idx = content.rfind(self.response_end_token)
+                if response_end_idx != -1:
+                    content = content[:response_end_idx]
+            elif self.response_end_token_id in delta_token_ids:
+                response_end_idx = content.rfind(self.response_end_token)
+                content = content[:response_end_idx]
+            # remove \n after </think> or <response> or </response>
+            if (
+                len(previous_token_ids) > 0
+                and previous_token_ids[-1] in self.parser_token_ids
+            ) and (
+                len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
+            ):
+                content = content.lstrip("\n")
+
+            return DeltaMessage(content=content if content else None)
+        logger.debug("cur_text = %s", cur_text)
+        end_idx = cur_text.find(self.tool_call_end_token)
+        if end_idx != -1:
+            # A complete tool call is buffered; parse and emit it whole.
+            if self.current_tool_id == -1:
+                self.current_tool_id = 0
+                self.prev_tool_call_arr = []
+                self.streamed_args_for_tool = []
+            while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+            while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                self.streamed_args_for_tool.append("")
+
+            extracted_tool_calls = self.extract_tool_calls(
+                cur_text[: end_idx + len(self.tool_call_end_token)], request
+            )
+
+            if len(extracted_tool_calls.tool_calls) == 0:
+                logger.warning("Failed to extract any tool calls.")
+                return None
+            tool_call = extracted_tool_calls.tool_calls[0]
+            self.prev_tool_call_arr[self.current_tool_id] = {
+                "name": tool_call.function.name,
+                "arguments": json.loads(tool_call.function.arguments),
+            }
+            self.streamed_args_for_tool[self.current_tool_id] = (
+                tool_call.function.arguments
+            )
+            delta = DeltaMessage(
+                content=extracted_tool_calls.content,
+                tool_calls=[
+                    DeltaToolCall(
+                        index=self.current_tool_id,
+                        id=tool_call.id,
+                        type=tool_call.type,
+                        function=DeltaFunctionCall(
+                            name=tool_call.function.name,
+                            arguments=tool_call.function.arguments,
+                        ),
+                    )
+                ],
+            )
+            self.current_tool_id += 1
+            # Keep any text after </tool_call> for the next round.
+            self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
+            return delta
+
+        # Tool call started but not finished: hold it, emit preceding content.
+        self._buffer = cur_text[start_idx:]
+        content = cur_text[:start_idx].rstrip("\n")
+        return DeltaMessage(content=content if content else None)
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 78d3bf35f2a32..10c990f361324 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -4,6 +4,7 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .basic_parsers import BaseThinkingReasoningParser
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from .ernie45_reasoning_parser import Ernie45ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
 from .gptoss_reasoning_parser import GptOssReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
@@ -19,6 +20,7
@@ __all__ = [
     "BaseThinkingReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
+    "Ernie45ReasoningParser",
     "GraniteReasoningParser",
     "HunyuanA13BReasoningParser",
     "Qwen3ReasoningParser",
diff --git a/vllm/reasoning/ernie45_reasoning_parser.py b/vllm/reasoning/ernie45_reasoning_parser.py
new file mode 100644
index 0000000000000..f9d4a30398cfd
--- /dev/null
+++ b/vllm/reasoning/ernie45_reasoning_parser.py
@@ -0,0 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("ernie45")
+class Ernie45ReasoningParser(BaseThinkingReasoningParser):
+    """
+    Reasoning parser for Ernie45 thinking model.
+    The Ernie45 thinking model output format is
+    abc</think>\n\n<response>\ndef\n</response>\n
+    or abc</think>\ndef
+    """
+
+    # Optional wrapper tokens around the final (non-reasoning) answer.
+    response_start_token: str = "<response>"
+    response_end_token: str = "</response>"
+    # The tokenizer encodes "\n" as the byte-level token <0x0A>.
+    newline_token: str = "<0x0A>"
+
+    @property
+    def start_token(self) -> str:
+        """The token that starts reasoning content."""
+        return "<think>"
+
+    @property
+    def end_token(self) -> str:
+        """The token that ends reasoning content."""
+        return "</think>"
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        """Resolve all delimiter token ids from the tokenizer vocabulary."""
+        super().__init__(tokenizer)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction."
+            )
+
+        self.start_token_id = self.vocab.get(self.start_token)
+        self.end_token_id = self.vocab.get(self.end_token)
+        self.response_start_token_id = self.vocab.get(self.response_start_token)
+        self.response_end_token_id = self.vocab.get(self.response_end_token)
+        self.newline_token_id = self.vocab.get(self.newline_token)
+
+        # Tokens after which a leading "\n" is template padding, not content.
+        self.parser_token_ids = [self.end_token_id, self.response_end_token_id]
+
+        if self.start_token_id is None or self.end_token_id is None:
+            raise RuntimeError(
+                "Ernie45 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!"
+            )
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        The Ernie45 thinking model output format is
+        abc</think>\n\n<response>\ndef\n</response>\n
+        or abc</think>\ndef
+        - 'abc' goes to reasoning_content
+        - 'def' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (
+            delta_token_ids[0]
+            in [
+                self.start_token_id,
+                self.end_token_id,
+                self.response_start_token_id,
+                self.response_end_token_id,
+            ]
+        ):
+            return None
+
+        # No </think> in previous or delta, also need to check for <response>.
+        # Because the model may have generated <response> without </think>
+        if self.end_token_id in delta_token_ids:
+            # </think> in delta with more tokens,
+            # extract reasoning content and content
+            think_end_index = delta_text.find(self.end_token)
+            reasoning_content = delta_text[:think_end_index]
+            content = delta_text[think_end_index + len(self.end_token) :]
+            content = content.lstrip("\n")
+            response_start_idx = content.find(self.response_start_token)
+            response_end_idx = content.rfind(self.response_end_token)
+            if response_start_idx != -1:
+                content = content[response_start_idx + len(self.response_start_token) :]
+            if response_end_idx != -1:
+                content = content[:response_end_idx]
+            return DeltaMessage(
+                reasoning_content=reasoning_content,
+                content=content if content else None,
+            )
+        elif self.end_token_id in previous_token_ids:
+            # </think> in previous, thinking content ends
+            content = delta_text
+            if self.response_start_token_id in delta_token_ids:
+                content = content.lstrip("\n")
+                response_start_idx = content.find(self.response_start_token)
+                content = content[response_start_idx + len(self.response_start_token) :]
+                # if have </response>, remove it
+                response_end_idx = content.rfind(self.response_end_token)
+                if response_end_idx != -1:
+                    content = content[:response_end_idx]
+            elif self.response_end_token_id in delta_token_ids:
+                response_end_idx = content.rfind(self.response_end_token)
+                content = content[:response_end_idx]
+            # remove \n after </think> or </response>
+            if previous_token_ids[-1] in self.parser_token_ids and (
+                len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
+            ):
+                content = content.lstrip("\n")
+            # remove \n after </think>\n
+            if (
+                len(previous_token_ids) > 1
+                and previous_token_ids[-2] == self.end_token_id
+            ) and (
+                len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
+            ):
+                content = content.lstrip("\n")
+
+            return DeltaMessage(content=content if content else None)
+        else:
+            # no </think> in previous or delta, reasoning content continues
+            return DeltaMessage(reasoning_content=delta_text)
+
+    def extract_reasoning_content(
+        self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[str | None, str | None]:
+        """
+        Extract reasoning content from the model output.
+        The Ernie45 thinking model output format is
+        abc</think>\n\n<response>\ndef\n</response>\n
+        or abc</think>\ndef
+        - 'abc' goes to reasoning_content
+        - 'def' goes to content
+        Returns:
+            tuple[Optional[str], Optional[str]]: reasoning content and content
+        """
+        reasoning_content, content = super().extract_reasoning_content(
+            model_output, request
+        )
+        if content:
+            start_idx = content.find(self.response_start_token)
+            end_idx = content.rfind(self.response_end_token)
+            # Simultaneously existing <response> and </response>
+            # in the correct order
+            if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
+                content = content[start_idx + len(self.response_start_token) : end_idx]
+        # Bound unconditionally so an empty/None content never leaves
+        # final_content undefined.
+        final_content = content or None
+
+        return reasoning_content, final_content