diff --git a/tests/reasoning/test_olmo3_reasoning_parser.py b/tests/reasoning/test_olmo3_reasoning_parser.py new file mode 100644 index 0000000000000..067be1ea29966 --- /dev/null +++ b/tests/reasoning/test_olmo3_reasoning_parser.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "olmo3" +START_REASONING = "<think>" +END_REASONING = "</think>" + +NO_REASONING = { + "output": f"{START_REASONING}{END_REASONING}No thoughts, head empty!", + "reasoning_content": None, + "content": "No thoughts, head empty!", +} + +NO_REASONING_WITH_NEWLINE = { + "output": + f"{START_REASONING}\n{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": "\n", + "content": "\n\nNo thoughts, head empty!", +} + +SIMPLE_REASONING = { + "output": + f"{START_REASONING}This is a reasoning section{END_REASONING}This is the rest", # noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} + +SIMPLE_REASONING_WITH_NEWLINE = { + "output": + f"{START_REASONING} Look!\n\nI'm thinking...{END_REASONING}\nThis is the rest", # noqa: E501 + "reasoning_content": " Look!\n\nI'm thinking...", + "content": "\nThis is the rest", +} + +SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES = { + "output": + f"{START_REASONING}\nLook!\nI'm thinking...\n\n{END_REASONING}\n\n\nThis is the rest", # noqa: E501 + "reasoning_content": "\nLook!\nI'm thinking...\n\n", + "content": "\n\n\nThis is the rest", +} + +NO_REASONING_ONLY_END_THINK = { + "output": f"{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": None, + "content": "\n\nNo thoughts, head empty!", +} + +REASONING_ONLY_END_THINK = { + "output": + f"The user is asking me not to think.{END_REASONING}No thoughts!", + "reasoning_content": "The 
user is asking me not to think.", + "content": "No thoughts!", +} + +TEST_CASES = [ + pytest.param( + False, # not streaming + NO_REASONING, + id="no_reasoning", + ), + pytest.param( + False, # not streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline", + ), + pytest.param( + False, # not streaming + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + False, # not streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines", + ), + pytest.param( + False, # not streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think", + ), + pytest.param( + False, # not streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think", + ), + pytest.param( + True, # enable streaming + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think_streaming", + ), + pytest.param( + True, # enable streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think_streaming", + ), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer") + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict[str, str], +): + output = 
tokenizer.tokenize(param_dict["output"]) + + # decode everything to tokens + model_output: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser: ReasoningParser = parser_cls(tokenizer) + + reasoning, content = run_reasoning_extraction(reasoning_parser=parser, + model_output=model_output, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 3c8a9c6ae0d33..78d3bf35f2a32 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -9,6 +9,7 @@ from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser +from .olmo3_reasoning_parser import Olmo3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser from .seedoss_reasoning_parser import SeedOSSReasoningParser from .step3_reasoning_parser import Step3ReasoningParser @@ -23,6 +24,7 @@ __all__ = [ "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", "MistralReasoningParser", + "Olmo3ReasoningParser", "Step3ReasoningParser", "GptOssReasoningParser", "SeedOSSReasoningParser", diff --git a/vllm/reasoning/olmo3_reasoning_parser.py b/vllm/reasoning/olmo3_reasoning_parser.py new file mode 100644 index 0000000000000..e37b6d887515f --- /dev/null +++ b/vllm/reasoning/olmo3_reasoning_parser.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses as dt +import enum +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import regex as re + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + +from 
vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, ResponsesRequest) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +class Olmo3ReasoningState(enum.Enum): + REASONING = 1 + CONTENT = 2 + + +@dt.dataclass(frozen=True) +class Indices: + start: int + end: int + + def __len__(self): + return self.end - self.start + + +def string_overlap(a: str, + b: str) -> tuple[Optional[Indices], Optional[Indices]]: + """ + Find the longest overlap where the end of string a matches the start + of string b. + + Args: + a: First string + b: Second string + + Returns: + Tuple of IndicesTuples representing the overlapping portions in each + string, or a tuple of None if no overlap exists + """ + + # swap so a is always the shorter string + a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True) + + # first check: is a fully contained in b? + if a in b: + ind_a = Indices(0, len(a)) + ind_b = Indices(b.index(a), b.index(a) + len(a)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # second check: does the end of a overlap with the + # beginning of b? + for i in range(len(a) - 1, 0, -1): + if a[-i:] == b[:i]: + ind_a = Indices(len(a) - i, len(a)) + ind_b = Indices(0, i) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # third check: does the beginning of a overlap with + # the end of b? + for i in range(len(a) - 1, 0, -1): + if b[-i:] == a[:i]: + ind_a = Indices(0, i) + ind_b = Indices(len(b) - i, len(b)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + return None, None + + +@dt.dataclass +class Olmo3ReasoningBuffer: + think_start: str = "<think>" + think_end: str = "</think>" + buffer: str = "" + + # we start in reasoning state to support cases where we hardcode + # <think> as the start of the reasoning block. + # In those cases, the only token we will see is </think>, which + # is when we switch to content state. 
+ state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING + + def process_buffer(self) -> Optional[DeltaMessage]: + start_think_idx = self.buffer.find(self.think_start) + + if start_think_idx >= 0: + self.state = Olmo3ReasoningState.REASONING + pretext, self.buffer = ( + self.buffer[:start_think_idx], + self.buffer[start_think_idx + len(self.think_start):], + ) + if start_think_idx > 0: + # this covers the case there's content before + # the start of the reasoning block + return DeltaMessage(content=pretext) + + end_think_idx = self.buffer.rfind(self.think_end) + + if end_think_idx >= 0: + self.state = Olmo3ReasoningState.CONTENT + pretext, self.buffer = ( + self.buffer[:end_think_idx], + self.buffer[end_think_idx + len(self.think_end):], + ) + if end_think_idx > 0: + # this covers the case there's content before + # the end of the reasoning block + return DeltaMessage(reasoning_content=pretext) + + if self.state == Olmo3ReasoningState.REASONING: + # we are inside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(reasoning_content=text_buffer) + + if self.state == Olmo3ReasoningState.CONTENT: + # we are outside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(content=text_buffer) + + # nothing to return unless we are in reasoning or content state + return None + + def __len__(self): + # is the length of the text buffer + return len(self.buffer) + + def add_text(self, delta_text: str) -> Optional[DeltaMessage]: + # we start by adding the delta text to the buffer + self.buffer += delta_text + + # setting this to empty before starting + delta_message: Optional[DeltaMessage] = None + + # we start by computing the overlap between the delta_text + # and start/end of think tokens. 
+ _, overlap_think_start = string_overlap(delta_text, self.think_start) + _, overlap_think_end = string_overlap(delta_text, self.think_end) + + partial_overlap_start = overlap_think_start is not None and len( + overlap_think_start) < len(self.think_start) + partial_overlap_end = overlap_think_end is not None and len( + overlap_think_end) < len(self.think_end) + + if (partial_overlap_start and self.think_start in self.buffer + and not partial_overlap_end): + # we can only process the buffer if partial overlap + # is the last part of think token (thus causing + # text_buffer to contain the start of think token) + # and there are no partial overlaps with end think + delta_message = self.process_buffer() + + elif partial_overlap_end and self.think_end in self.buffer: + # same as before (partial overlap only allowed) + # if the buffer contains the end think token, + # but we don't have to check for partial overlap + # with start think token because they are handled + # by the previous condition + delta_message = self.process_buffer() + + elif partial_overlap_start or partial_overlap_end: + # in general, if there are overlaps, we don't + # process the buffer because we want to wait until + # the think token is fully completed. + return None + else: + # we process the buffer as normal + delta_message = self.process_buffer() + + return delta_message + + +@ReasoningParserManager.register_module("olmo3") +class Olmo3ReasoningParser(ReasoningParser): + """ + Reasoning parser for Olmo 3 model + + Olmo3ReasoningParser + + This class implements a reasoning parser specifically designed for the + Olmo 3 family of models. Olmo 3 models do not use special tokens to + indicate reasoning; rather, reasoning trace is wrapped in `<think>` and + `</think>`, which are tokenized using standard vocabulary entries. + Because of this, the parser operates in string space, accumulating the + characters in a buffer until it sees `<think>` or `</think>` tokens + to switch modes. 
+ + Key Features: + - For non-stream output, Recognizes and extracts reasoning (text + bracketed by `<think>` and `</think>`) and content (everything + after the first `</think>`). + - For stream process, it uses a buffer to accumulate delta text, + and output progressive delta messages as soon as thinking starts + or ends. + - For reliability, some Olmo 3 models may hardcode the first + `<think>` token in the input text (similar to Deepseek R1, + or reasoning-only Qwen models). To support such variants, the + parser can optionally work in cases where the first `<think>` + token is missing from generation. + """ + + def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + self.think_start = r"<think>" + self.think_end = r"</think>" + + # notice that the first think is optional; this allows template to + # work in cases when we hardcode a <think> at the beginning of the + # reasoning template. + reasoning_expr = (rf"^(?:{self.think_start})?(?P<reasoning>.*?)" + + rf"{self.think_end}(?P<content>.*)$") + self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL) + + self.buffer = Olmo3ReasoningBuffer(think_start=self.think_start, + think_end=self.think_end) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return self.think_end in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # for Olmo 3 streaming reason parsing, the stream parse + # will call </think> first, and the same token will be called in + # is_reasoning_end and extract_content_ids + # this id is not part of content, so just return [] here. + return [] + + def extract_reasoning_content( + self, + model_output: str, + request: Union[ChatCompletionRequest, ResponsesRequest], + ) -> tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. 
+ + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionRequest | ResponsesRequest): Request being + processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + + re_match = self.reasoning_regex.match(model_output) + if re_match: + reasoning_content = re_match.group("reasoning") or None + content = re_match.group("content") or None + return reasoning_content, content + + # no reasoning content + return None, model_output + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract content using token ID sequence state machine""" + + delta_message = self.buffer.add_text(delta_text) + if (delta_message is None + and self.buffer.think_end in self.buffer.buffer): + # this is a bit hacky, but, because of how the buffer is + # constructed, if the last delta_text contains characters that + # marks the end of thinking tokens, then messages in the buffer + # would never be processed because we get no other turn. To get + # around that, we check if the text buffer contains the end of + # thinking tokens, and, if so, we reprocess the buffer again. + delta_message = self.buffer.process_buffer() + + return delta_message