mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 13:34:32 +08:00
Add Olmo 3 reasoning parser (#26054)
Signed-off-by: Luca Soldaini <luca@soldaini.net>
This commit is contained in:
parent
1838cd4860
commit
d0df145c2a
157
tests/reasoning/test_olmo3_reasoning_parser.py
Normal file
157
tests/reasoning/test_olmo3_reasoning_parser.py
Normal file
@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "olmo3"
|
||||
START_REASONING = "<think>"
|
||||
END_REASONING = "</think>"
|
||||
|
||||
NO_REASONING = {
|
||||
"output": f"{START_REASONING}{END_REASONING}No thoughts, head empty!",
|
||||
"reasoning_content": None,
|
||||
"content": "No thoughts, head empty!",
|
||||
}
|
||||
|
||||
NO_REASONING_WITH_NEWLINE = {
|
||||
"output":
|
||||
f"{START_REASONING}\n{END_REASONING}\n\nNo thoughts, head empty!",
|
||||
"reasoning_content": "\n",
|
||||
"content": "\n\nNo thoughts, head empty!",
|
||||
}
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
"output":
|
||||
f"{START_REASONING}This is a reasoning section{END_REASONING}This is the rest", # noqa: E501
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
|
||||
SIMPLE_REASONING_WITH_NEWLINE = {
|
||||
"output":
|
||||
f"{START_REASONING} Look!\n\nI'm thinking...{END_REASONING}\nThis is the rest", # noqa: E501
|
||||
"reasoning_content": " Look!\n\nI'm thinking...",
|
||||
"content": "\nThis is the rest",
|
||||
}
|
||||
|
||||
SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES = {
|
||||
"output":
|
||||
f"{START_REASONING}\nLook!\nI'm thinking...\n\n{END_REASONING}\n\n\nThis is the rest", # noqa: E501
|
||||
"reasoning_content": "\nLook!\nI'm thinking...\n\n",
|
||||
"content": "\n\n\nThis is the rest",
|
||||
}
|
||||
|
||||
NO_REASONING_ONLY_END_THINK = {
|
||||
"output": f"{END_REASONING}\n\nNo thoughts, head empty!",
|
||||
"reasoning_content": None,
|
||||
"content": "\n\nNo thoughts, head empty!",
|
||||
}
|
||||
|
||||
REASONING_ONLY_END_THINK = {
|
||||
"output":
|
||||
f"The user is asking me not to think.{END_REASONING}No thoughts!",
|
||||
"reasoning_content": "The user is asking me not to think.",
|
||||
"content": "No thoughts!",
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
NO_REASONING,
|
||||
id="no_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
NO_REASONING_WITH_NEWLINE,
|
||||
id="no_reasoning_with_newline",
|
||||
),
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
SIMPLE_REASONING_WITH_NEWLINE,
|
||||
id="simple_reasoning_with_newline",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES,
|
||||
id="simple_reasoning_with_multiple_newlines",
|
||||
),
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
NO_REASONING_ONLY_END_THINK,
|
||||
id="no_reasoning_only_end_think",
|
||||
),
|
||||
pytest.param(
|
||||
False, # not streaming
|
||||
REASONING_ONLY_END_THINK,
|
||||
id="yes_reasoning_only_end_think",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
NO_REASONING,
|
||||
id="no_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
NO_REASONING_WITH_NEWLINE,
|
||||
id="no_reasoning_with_newline_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
SIMPLE_REASONING_WITH_NEWLINE,
|
||||
id="simple_reasoning_with_newline_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES,
|
||||
id="simple_reasoning_with_multiple_newlines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
NO_REASONING_ONLY_END_THINK,
|
||||
id="no_reasoning_only_end_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True, # enable streaming
|
||||
REASONING_ONLY_END_THINK,
|
||||
id="yes_reasoning_only_end_think_streaming",
|
||||
),
|
||||
]
|
||||
|
||||
# Global tokenizer initialization to avoid repeated loading
|
||||
tokenizer = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict[str, str],
|
||||
):
|
||||
output = tokenizer.tokenize(param_dict["output"])
|
||||
|
||||
# decode everything to tokens
|
||||
model_output: list[str] = [
|
||||
tokenizer.convert_tokens_to_string([token]) for token in output
|
||||
]
|
||||
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||
parser: ReasoningParser = parser_cls(tokenizer)
|
||||
|
||||
reasoning, content = run_reasoning_extraction(reasoning_parser=parser,
|
||||
model_output=model_output,
|
||||
streaming=streaming)
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
@ -9,6 +9,7 @@ from .gptoss_reasoning_parser import GptOssReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .mistral_reasoning_parser import MistralReasoningParser
|
||||
from .olmo3_reasoning_parser import Olmo3ReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
from .seedoss_reasoning_parser import SeedOSSReasoningParser
|
||||
from .step3_reasoning_parser import Step3ReasoningParser
|
||||
@ -23,6 +24,7 @@ __all__ = [
|
||||
"Qwen3ReasoningParser",
|
||||
"Glm4MoeModelReasoningParser",
|
||||
"MistralReasoningParser",
|
||||
"Olmo3ReasoningParser",
|
||||
"Step3ReasoningParser",
|
||||
"GptOssReasoningParser",
|
||||
"SeedOSSReasoningParser",
|
||||
|
||||
294
vllm/reasoning/olmo3_reasoning_parser.py
Normal file
294
vllm/reasoning/olmo3_reasoning_parser.py
Normal file
@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses as dt
|
||||
import enum
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage, ResponsesRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Olmo3ReasoningState(enum.Enum):
|
||||
REASONING = 1
|
||||
CONTENT = 2
|
||||
|
||||
|
||||
@dt.dataclass(frozen=True)
|
||||
class Indices:
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def __len__(self):
|
||||
return self.end - self.start
|
||||
|
||||
|
||||
def string_overlap(a: str,
|
||||
b: str) -> tuple[Optional[Indices], Optional[Indices]]:
|
||||
"""
|
||||
Find the longest overlap where the end of string a matches the start
|
||||
of string b.
|
||||
|
||||
Args:
|
||||
a: First string
|
||||
b: Second string
|
||||
|
||||
Returns:
|
||||
Tuple of IndicesTuples representing the overlapping portions in each
|
||||
string, or a tuple of None if no overlap exists
|
||||
"""
|
||||
|
||||
# swap so a is always the shorter string
|
||||
a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True)
|
||||
|
||||
# first check: is a fully contained in b?
|
||||
if a in b:
|
||||
ind_a = Indices(0, len(a))
|
||||
ind_b = Indices(b.index(a), b.index(a) + len(a))
|
||||
return (ind_b, ind_a) if swap else (ind_a, ind_b)
|
||||
|
||||
# second check: does the end of a overlap with the
|
||||
# beginning of b?
|
||||
for i in range(len(a) - 1, 0, -1):
|
||||
if a[-i:] == b[:i]:
|
||||
ind_a = Indices(len(a) - i, len(a))
|
||||
ind_b = Indices(0, i)
|
||||
return (ind_b, ind_a) if swap else (ind_a, ind_b)
|
||||
|
||||
# third check: does the beginning of a overlap with
|
||||
# the end of b?
|
||||
for i in range(len(a) - 1, 0, -1):
|
||||
if b[-i:] == a[:i]:
|
||||
ind_a = Indices(0, i)
|
||||
ind_b = Indices(len(b) - i, len(b))
|
||||
return (ind_b, ind_a) if swap else (ind_a, ind_b)
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
@dt.dataclass
|
||||
class Olmo3ReasoningBuffer:
|
||||
think_start: str = "<think>"
|
||||
think_end: str = "</think>"
|
||||
buffer: str = ""
|
||||
|
||||
# we start in reasoning state to support cases where we hardcode
|
||||
# <think> as the start of the reasoning block.
|
||||
# In those cases, the only token we will see is </think>, which
|
||||
# is when we switch to content state.
|
||||
state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING
|
||||
|
||||
def process_buffer(self) -> Optional[DeltaMessage]:
|
||||
start_think_idx = self.buffer.find(self.think_start)
|
||||
|
||||
if start_think_idx >= 0:
|
||||
self.state = Olmo3ReasoningState.REASONING
|
||||
pretext, self.buffer = (
|
||||
self.buffer[:start_think_idx],
|
||||
self.buffer[start_think_idx + len(self.think_start):],
|
||||
)
|
||||
if start_think_idx > 0:
|
||||
# this covers the case there's content before
|
||||
# the start of the reasoning block
|
||||
return DeltaMessage(content=pretext)
|
||||
|
||||
end_think_idx = self.buffer.rfind(self.think_end)
|
||||
|
||||
if end_think_idx >= 0:
|
||||
self.state = Olmo3ReasoningState.CONTENT
|
||||
pretext, self.buffer = (
|
||||
self.buffer[:end_think_idx],
|
||||
self.buffer[end_think_idx + len(self.think_end):],
|
||||
)
|
||||
if end_think_idx > 0:
|
||||
# this covers the case there's content before
|
||||
# the end of the reasoning block
|
||||
return DeltaMessage(reasoning_content=pretext)
|
||||
|
||||
if self.state == Olmo3ReasoningState.REASONING:
|
||||
# we are inside reasoning block, return and empty
|
||||
# the text buffer
|
||||
(
|
||||
text_buffer,
|
||||
self.buffer,
|
||||
) = self.buffer, ""
|
||||
return DeltaMessage(reasoning_content=text_buffer)
|
||||
|
||||
if self.state == Olmo3ReasoningState.CONTENT:
|
||||
# we are outside reasoning block, return and empty
|
||||
# the text buffer
|
||||
(
|
||||
text_buffer,
|
||||
self.buffer,
|
||||
) = self.buffer, ""
|
||||
return DeltaMessage(content=text_buffer)
|
||||
|
||||
# nothing to return unless we are in reasoning or content state
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
# is the length of the text buffer
|
||||
return len(self.buffer)
|
||||
|
||||
def add_text(self, delta_text: str) -> Optional[DeltaMessage]:
|
||||
# we start by adding the delta text to the buffer
|
||||
self.buffer += delta_text
|
||||
|
||||
# setting this to empty before starting
|
||||
delta_message: Optional[DeltaMessage] = None
|
||||
|
||||
# we start by computing the overlap between the delta_text
|
||||
# and start/end of think tokens.
|
||||
_, overlap_think_start = string_overlap(delta_text, self.think_start)
|
||||
_, overlap_think_end = string_overlap(delta_text, self.think_end)
|
||||
|
||||
partial_overlap_start = overlap_think_start is not None and len(
|
||||
overlap_think_start) < len(self.think_start)
|
||||
partial_overlap_end = overlap_think_end is not None and len(
|
||||
overlap_think_end) < len(self.think_end)
|
||||
|
||||
if (partial_overlap_start and self.think_start in self.buffer
|
||||
and not partial_overlap_end):
|
||||
# we can only process the buffer if partial overlap
|
||||
# is the last part of think token (thus causing
|
||||
# text_buffer to contain the start of think token)
|
||||
# and there are no partial overlaps with end think
|
||||
delta_message = self.process_buffer()
|
||||
|
||||
elif partial_overlap_end and self.think_end in self.buffer:
|
||||
# same as before (partial overlap only allowed)
|
||||
# if the buffer contains the end think token,
|
||||
# but we don't have to check for partial overlap
|
||||
# with start think token because they are handled
|
||||
# by the previous condition
|
||||
delta_message = self.process_buffer()
|
||||
|
||||
elif partial_overlap_start or partial_overlap_end:
|
||||
# in general, if there are overlaps, we don't
|
||||
# process the buffer because we want to wait until
|
||||
# the think token is fully completed.
|
||||
return None
|
||||
else:
|
||||
# we process the buffer as normal
|
||||
delta_message = self.process_buffer()
|
||||
|
||||
return delta_message
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("olmo3")
|
||||
class Olmo3ReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Olmo 3 model
|
||||
|
||||
Olmo3ReasoningParser
|
||||
|
||||
This class implements a reasoning parser specifically designed for the
|
||||
Olmo 3 family of models. Olmo 3 models do not use special tokens to
|
||||
indicate reasoning; rather, reasoning trace is wrapped in `<think>` and
|
||||
`</think>`, which are tokenized using standard vocabulary entries.
|
||||
Because of this, the parser operates in string space, accumulating the
|
||||
characters in a buffer until it sees `<think>` or `</think>`. tokens
|
||||
to switch modes.
|
||||
|
||||
Key Features:
|
||||
- For non-stream output, Recognizes and extracts reasoning (text
|
||||
bracketed by `<think>` and `</think>`) and content (everything
|
||||
after the first `</think>`).
|
||||
- For stream process, it uses a buffer to accumulate delta text,
|
||||
and output progressive delta messages as soon as thinking starts
|
||||
or ends.
|
||||
- For reliability, some Olmo 3 models may hardcode the first
|
||||
`<think>` token is the input text (similar to Deepseek R1,
|
||||
or reasoning-only Qwen models). To support such variants, the
|
||||
parser can optionally work in cases where the first `<think>`
|
||||
token is missing from generation.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
self.think_start = r"<think>"
|
||||
self.think_end = r"</think>"
|
||||
|
||||
# notice that the first think is optional; this allows template to
|
||||
# work in cases when we hardcode a <think> at the beginning of the
|
||||
# reasoning template.
|
||||
reasoning_expr = (rf"^(?:{self.think_start})?(?P<reasoning>.*?)" +
|
||||
rf"{self.think_end}(?P<content>.*)$")
|
||||
self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)
|
||||
|
||||
self.buffer = Olmo3ReasoningBuffer(think_start=self.think_start,
|
||||
think_end=self.think_end)
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
text = self.model_tokenizer.decode(input_ids)
|
||||
return self.think_end in text
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
# for Olmo 3 streaming reason parsing, the stream parse
|
||||
# will call first, and the same token will be called in
|
||||
# is_reasoning_end and extract_content_ids
|
||||
# this id is not part of content, so just return [] here.
|
||||
return []
|
||||
|
||||
def extract_reasoning_content(
|
||||
self,
|
||||
model_output: str,
|
||||
request: Union[ChatCompletionRequest, ResponsesRequest],
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Extract the reasoning content & content sections, respectively.
|
||||
If the sequence doesn't match what we expect, i.e., the model generates
|
||||
something else, all content is considered non-reasoning content.
|
||||
|
||||
Args:
|
||||
model_output (str): Output of the model to be parsed.
|
||||
request (ChatCompletionRequest | ResponsesRequest): Request being
|
||||
processed.
|
||||
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: Tuple pair containing the
|
||||
reasoning content and non-reasoning content.
|
||||
"""
|
||||
|
||||
re_match = self.reasoning_regex.match(model_output)
|
||||
if re_match:
|
||||
reasoning_content = re_match.group("reasoning") or None
|
||||
content = re_match.group("content") or None
|
||||
return reasoning_content, content
|
||||
|
||||
# no reasoning content
|
||||
return None, model_output
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""Extract content using token ID sequence state machine"""
|
||||
|
||||
delta_message = self.buffer.add_text(delta_text)
|
||||
if (delta_message is None
|
||||
and self.buffer.think_end in self.buffer.buffer):
|
||||
# this is a bit hacky, but, because of how the buffer is
|
||||
# constructed, if the last delta_text contains characters that
|
||||
# marks the end of thinking tokens, then messages in the buffer
|
||||
# would never be processed because we get no other turn. To get
|
||||
# around that, we check if the text buffer contains the end of
|
||||
# thinking tokens, and, if so, we reprocess the buffer again.
|
||||
delta_message = self.buffer.process_buffer()
|
||||
|
||||
return delta_message
|
||||
Loading…
x
Reference in New Issue
Block a user