diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py
new file mode 100644
index 000000000000..6a939dcfc2c9
--- /dev/null
+++ b/tests/reasoning/test_base_thinking_reasoning_parser.py
@@ -0,0 +1,392 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+# Create a concrete test implementation of BaseThinkingReasoningParser
+class TestThinkingReasoningParser(BaseThinkingReasoningParser):
+ """Test implementation of BaseThinkingReasoningParser."""
+
+ @property
+ def start_token(self) -> str:
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+
+class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
+ """Alternative test implementation with different tokens."""
+
+ @property
+ def start_token(self) -> str:
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+
+# Use a test model
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def test_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+ # Add custom test tokens
+ test_tokens = ["", "", "", ""]
+ existing_tokens = set(tokenizer.get_vocab().keys())
+ new_tokens = [
+ token for token in test_tokens if token not in existing_tokens
+ ]
+ if new_tokens:
+ tokenizer.add_tokens(new_tokens)
+ return tokenizer
+
+
+class TestBaseThinkingReasoningParserInit:
+ """
+ Test initialization and basic properties of
+ BaseThinkingReasoningParser.
+ """
+
+ def test_successful_initialization(self, test_tokenizer):
+ """Test successful initialization with valid tokens."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ assert parser.start_token == ""
+ assert parser.end_token == ""
+ assert parser.start_token_id is not None
+ assert parser.end_token_id is not None
+
+ def test_initialization_with_missing_tokenizer(self):
+ """Test that initialization fails without tokenizer."""
+ with pytest.raises(ValueError, match="model tokenizer must be passed"):
+ TestThinkingReasoningParser(None)
+
+ def test_initialization_with_missing_tokens(self, test_tokenizer):
+ """Test that initialization fails when tokens are not in vocabulary."""
+
+ # Create a parser with tokens not in vocabulary
+ class MissingTokenParser(BaseThinkingReasoningParser):
+
+ @property
+ def start_token(self) -> str:
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+ with pytest.raises(RuntimeError,
+ match="could not locate think start/end tokens"):
+ MissingTokenParser(test_tokenizer)
+
+ def test_initialization_with_empty_tokens(self, test_tokenizer):
+ """Test that initialization fails with empty token strings."""
+
+ class EmptyTokenParser(BaseThinkingReasoningParser):
+
+ @property
+ def start_token(self) -> str:
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+ with pytest.raises(ValueError,
+ match="start_token and end_token must be defined"):
+ EmptyTokenParser(test_tokenizer)
+
+
+class TestBaseThinkingReasoningParserMethods:
+ """Test the methods of BaseThinkingReasoningParser."""
+
+ def test_is_reasoning_end(self, test_tokenizer):
+ """Test the is_reasoning_end method."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ end_token_id = parser.end_token_id
+
+ # Test with end token present
+ assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
+
+ # Test without end token
+ assert parser.is_reasoning_end([1, 2, 3, 4]) is False
+
+ # Test with empty list
+ assert parser.is_reasoning_end([]) is False
+
+ def test_extract_content_ids(self, test_tokenizer):
+ """Test the extract_content_ids method."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ end_token_id = parser.end_token_id
+
+ # Test with end token in the middle
+ input_ids = [1, 2, end_token_id, 4, 5]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == [4, 5]
+
+ # Test with end token at the end
+ input_ids = [1, 2, 3, end_token_id]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == []
+
+ # Test without end token
+ input_ids = [1, 2, 3, 4]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == []
+
+ # Test with end token as last element (should not extract)
+ input_ids = [1, 2, 3, end_token_id]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == []
+
+
+class TestBaseThinkingReasoningParserExtraction:
+ """Test reasoning content extraction methods."""
+
+ def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
+ """Test extraction when both start and end tokens are present."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ request = ChatCompletionRequest(messages=[], model="test-model")
+
+ model_output = ("This is reasoning"
+ "This is content")
+ reasoning, content = parser.extract_reasoning_content(
+ model_output, request)
+
+ assert reasoning == "This is reasoning"
+ assert content == "This is content"
+
+ def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
+ """Test extraction when only end token is present."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ request = ChatCompletionRequest(messages=[], model="test-model")
+
+ model_output = ("This is reasoningThis is content")
+ reasoning, content = parser.extract_reasoning_content(
+ model_output, request)
+
+ assert reasoning == "This is reasoning"
+ assert content == "This is content"
+
+ def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
+ """Test extraction when no end token is present."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ request = ChatCompletionRequest(messages=[], model="test-model")
+
+ model_output = "This is just content"
+ reasoning, content = parser.extract_reasoning_content(
+ model_output, request)
+
+ assert reasoning == "This is just content"
+ assert content is None
+
+ def test_extract_reasoning_content_empty_output(self, test_tokenizer):
+ """Test extraction with empty output."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ request = ChatCompletionRequest(messages=[], model="test-model")
+
+ model_output = ""
+ reasoning, content = parser.extract_reasoning_content(
+ model_output, request)
+
+ assert reasoning == ""
+ assert content is None
+
+ def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
+ """Test extraction with only tokens and no content."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+ request = ChatCompletionRequest(messages=[], model="test-model")
+
+ model_output = ("")
+ reasoning, content = parser.extract_reasoning_content(
+ model_output, request)
+
+ assert reasoning == ""
+ assert content is None
+
+
+class TestBaseThinkingReasoningParserStreaming:
+ """Test streaming functionality of BaseThinkingReasoningParser."""
+
+ @pytest.mark.parametrize("streaming", [True, False])
+ def test_simple_reasoning_extraction(self, test_tokenizer, streaming):
+ """
+ Test basic reasoning extraction in both
+ streaming and non-streaming modes.
+ """
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ model_output = [
+ "", "Some ", "reasoning ", "content", "",
+ "Final ", "answer"
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ model_output,
+ streaming=streaming)
+
+ assert reasoning == "Some reasoning content"
+ assert content == "Final answer"
+
+ def test_streaming_with_incremental_deltas(self, test_tokenizer):
+ """Test streaming processing with small incremental deltas."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ deltas = [
+ "",
+ "Some ",
+ "reasoning ",
+ "content",
+ "",
+ "Final ",
+ "answer",
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ deltas,
+ streaming=True)
+
+ assert reasoning == "Some reasoning content"
+ assert content == "Final answer"
+
+ def test_streaming_with_start_token(self, test_tokenizer):
+ """Test streaming with start token included."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ deltas = [
+ "",
+ "Some ",
+ "reasoning",
+ "",
+ "Answer",
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ deltas,
+ streaming=True)
+
+ assert reasoning == "Some reasoning"
+ assert content == "Answer"
+
+ def test_streaming_no_end_token(self, test_tokenizer):
+ """Test streaming when no end token is encountered."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ deltas = [
+ "",
+ "Some ",
+ "reasoning ",
+ "without ",
+ "end",
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ deltas,
+ streaming=True)
+
+ assert reasoning == "Some reasoning without end"
+ assert content is None
+
+ def test_streaming_only_end_token(self, test_tokenizer):
+ """Test streaming when only end token appears."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ deltas = [
+ "",
+ "Reasoning ",
+ "content",
+ "",
+ "Final",
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ deltas,
+ streaming=True)
+
+ assert reasoning == "Reasoning content"
+ assert content == "Final"
+
+
+class TestBaseThinkingReasoningParserMultipleImplementations:
+ """
+ Test that multiple implementations of
+ BaseThinkingReasoningParser work correctly.
+ """
+
+ def test_different_token_implementations(self, test_tokenizer):
+ """
+ Test that different implementations
+ with different tokens work independently.
+ """
+ parser1 = TestThinkingReasoningParser(test_tokenizer)
+ parser2 = TestThinkingReasoningParserAlt(test_tokenizer)
+
+ # Test parser1
+ model_output1 = ("Reasoning1Content1")
+ reasoning1, content1 = run_reasoning_extraction(
+ parser1, [model_output1])
+ assert reasoning1 == "Reasoning1"
+ assert content1 == "Content1"
+
+ # Test parser2
+ model_output2 = "Reasoning2Content2"
+ reasoning2, content2 = run_reasoning_extraction(
+ parser2, [model_output2])
+ assert reasoning2 == "Reasoning2"
+ assert content2 == "Content2"
+
+ # Verify tokens are different
+ assert parser1.start_token != parser2.start_token
+ assert parser1.end_token != parser2.end_token
+ assert parser1.start_token_id != parser2.start_token_id
+ assert parser1.end_token_id != parser2.end_token_id
+
+
+class TestBaseThinkingReasoningParserEdgeCases:
+ """Test edge cases and error conditions."""
+
+ def test_multiple_end_tokens(self, test_tokenizer):
+ """Test behavior with multiple end tokens."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ model_output = ("FirstMiddleLast")
+ reasoning, content = run_reasoning_extraction(parser, [model_output])
+
+ # Should stop at first end token
+ assert reasoning == "First"
+ assert content == "MiddleLast"
+
+ def test_nested_tokens(self, test_tokenizer):
+ """Test behavior with nested-like token patterns."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ model_output = ("Outer"
+ "InnerContent")
+ reasoning, content = run_reasoning_extraction(parser, [model_output])
+
+ # Should process normally, start from first start token
+ assert reasoning == "OuterInner"
+ assert content == "Content"
+
+ def test_malformed_tokens(self, test_tokenizer):
+ """Test behavior with malformed token-like strings."""
+ parser = TestThinkingReasoningParser(test_tokenizer)
+
+ model_output = ("Not a real token"
+ "Content")
+ reasoning, content = run_reasoning_extraction(parser, [model_output])
+
+ # Should treat as regular content since tokens don't match exactly
+ assert reasoning == ("Not a real token"
+ "Content")
+ assert content is None
diff --git a/tests/reasoning/test_seedoss_reasoning_parser.py b/tests/reasoning/test_seedoss_reasoning_parser.py
new file mode 100644
index 000000000000..bb5dc0f4ffe4
--- /dev/null
+++ b/tests/reasoning/test_seedoss_reasoning_parser.py
@@ -0,0 +1,237 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any, cast
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+parser_name = "seed_oss"
+start_token = ""
+end_token = ""
+
+# Use a test model that contains our custom tokens
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def seedoss_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+ # Add custom SeedOSS tokens if they don't exist
+ if start_token not in tokenizer.get_vocab():
+ tokenizer.add_tokens([start_token, end_token])
+ return tokenizer
+
+
+SIMPLE_REASONING: dict[str, Any] = {
+ "output": "This is a reasoning sectionThis is the rest",
+ "reasoning_content": "This is a reasoning section",
+ "content": "This is the rest",
+ "is_reasoning_end": True,
+}
+COMPLETE_REASONING: dict[str, Any] = {
+ "output": "This is a reasoning section",
+ "reasoning_content": "This is a reasoning section",
+ "content": None,
+ "is_reasoning_end": True,
+}
+NO_CONTENT: dict[str, Any] = {
+ "output": "This is content",
+ "reasoning_content": "This is content",
+ "content": None,
+ "is_reasoning_end": False,
+}
+NO_REASONING_STREAMING: dict[str, Any] = {
+ "output": "This is a reasoning section",
+ "reasoning_content": "This is a reasoning section",
+ "content": None,
+ "is_reasoning_end": False,
+}
+MULTIPLE_LINES: dict[str, Any] = {
+ "output": "This\nThatThis is the rest\nThat",
+ "reasoning_content": "This\nThat",
+ "content": "This is the rest\nThat",
+ "is_reasoning_end": True,
+}
+WITH_START_TOKEN: dict[str, Any] = {
+ "output": ("This is a reasoning section"
+ "This is the rest"),
+ "reasoning_content":
+ "This is a reasoning section",
+ "content":
+ "This is the rest",
+ "is_reasoning_end":
+ True,
+}
+ONLY_END_TOKEN: dict[str, Any] = {
+ "output": "Some reasoningThis is the rest",
+ "reasoning_content": "Some reasoning",
+ "content": "This is the rest",
+ "is_reasoning_end": True,
+}
+NO_TOKENS: dict[str, Any] = {
+ "output": "This is just content without any reasoning tokens",
+ "reasoning_content": "This is just content without any reasoning tokens",
+ "content": None,
+ "is_reasoning_end": False,
+}
+
+
+def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
+ """Test that the SeedOSS reasoning parser can be created and registered."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+ assert isinstance(parser, ReasoningParser)
+ assert parser.start_token == start_token
+ assert parser.end_token == end_token
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_simple_reasoning(seedoss_tokenizer, streaming):
+ """Test basic reasoning extraction with both tokens."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming)
+
+ assert reasoning == SIMPLE_REASONING["reasoning_content"]
+ assert content == SIMPLE_REASONING["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_complete_reasoning(seedoss_tokenizer, streaming):
+ """Test reasoning extraction when there's no content after reasoning."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming)
+
+ assert reasoning == COMPLETE_REASONING["reasoning_content"]
+ assert content == COMPLETE_REASONING["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_no_content(seedoss_tokenizer, streaming):
+ """Test when there's no end token - everything is reasoning content."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, NO_CONTENT["output"])], streaming=streaming)
+
+ assert reasoning == NO_CONTENT["reasoning_content"]
+ assert content == NO_CONTENT["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_multiple_lines(seedoss_tokenizer, streaming):
+ """Test reasoning extraction with multiline content."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming)
+
+ assert reasoning == MULTIPLE_LINES["reasoning_content"]
+ assert content == MULTIPLE_LINES["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_with_start_token(seedoss_tokenizer, streaming):
+ """Test reasoning extraction with both start and end tokens."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming)
+
+ assert reasoning == WITH_START_TOKEN["reasoning_content"]
+ assert content == WITH_START_TOKEN["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_only_end_token(seedoss_tokenizer, streaming):
+ """
+ Test reasoning extraction with only end token
+ (SeedOSS typical behavior).
+ """
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming)
+
+ assert reasoning == ONLY_END_TOKEN["reasoning_content"]
+ assert content == ONLY_END_TOKEN["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_no_tokens(seedoss_tokenizer, streaming):
+ """Test when there are no reasoning tokens at all."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ reasoning, content = run_reasoning_extraction(
+ parser, [cast(str, NO_TOKENS["output"])], streaming=streaming)
+
+ assert reasoning == NO_TOKENS["reasoning_content"]
+ assert content == NO_TOKENS["content"]
+
+
+def test_is_reasoning_end(seedoss_tokenizer):
+ """Test the is_reasoning_end method."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ # Test with end token present
+ end_token_id = parser.end_token_id
+ assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
+
+ # Test without end token
+ assert parser.is_reasoning_end([1, 2, 3, 4]) is False
+
+
+def test_extract_content_ids(seedoss_tokenizer):
+ """Test the extract_content_ids method."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ end_token_id = parser.end_token_id
+
+ # Test with end token in the middle
+ input_ids = [1, 2, end_token_id, 4, 5]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == [4, 5]
+
+ # Test with end token at the end
+ input_ids = [1, 2, 3, end_token_id]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == []
+
+ # Test without end token
+ input_ids = [1, 2, 3, 4]
+ content_ids = parser.extract_content_ids(input_ids)
+ assert content_ids == []
+
+
+def test_streaming_delta_processing(seedoss_tokenizer):
+ """Test streaming processing with small deltas."""
+ parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+ parser = parser_cls(seedoss_tokenizer)
+
+ # Test streaming with incremental tokens
+ deltas = [
+ "Some ", "reasoning ", "content", "", "Final ", "answer"
+ ]
+
+ reasoning, content = run_reasoning_extraction(parser,
+ deltas,
+ streaming=True)
+
+ assert reasoning == "Some reasoning content"
+ assert content == "Final answer"
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index b987adeb6428..3c8a9c6ae0d3 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
+from .basic_parsers import BaseThinkingReasoningParser
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .gptoss_reasoning_parser import GptOssReasoningParser
@@ -9,10 +10,12 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
+from .seedoss_reasoning_parser import SeedOSSReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
+ "BaseThinkingReasoningParser",
"ReasoningParserManager",
"DeepSeekR1ReasoningParser",
"GraniteReasoningParser",
@@ -22,4 +25,5 @@ __all__ = [
"MistralReasoningParser",
"Step3ReasoningParser",
"GptOssReasoningParser",
+ "SeedOSSReasoningParser",
]
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index df9e84163f16..39b08ec11107 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -7,7 +7,7 @@ import os
from abc import abstractmethod
from collections.abc import Sequence
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Union
from vllm.logger import init_logger
from vllm.utils import import_from_path, is_list_of
@@ -77,7 +77,7 @@ class ReasoningParser:
self,
model_output: str,
request: Union[ChatCompletionRequest, ResponsesRequest],
- ) -> tuple[Optional[str], Optional[str]]:
+ ) -> tuple[str | None, str | None]:
"""
Extract reasoning content from a complete model-generated string.
@@ -135,7 +135,7 @@ class ReasoningParserManager:
def _register_module(
cls,
module: type,
- module_name: Optional[Union[str, list[str]]] = None,
+ module_name: Union[str, list[str]] | None = None,
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser):
@@ -155,7 +155,7 @@ class ReasoningParserManager:
@classmethod
def register_module(
cls,
- name: Optional[Union[str, list[str]]] = None,
+ name: Union[str, list[str]] | None = None,
force: bool = True,
module: Union[type, None] = None,
) -> Union[type, Callable]:
diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py
new file mode 100644
index 000000000000..03cb882c2693
--- /dev/null
+++ b/vllm/reasoning/basic_parsers.py
@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import abstractmethod
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage, ResponsesRequest)
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+class BaseThinkingReasoningParser(ReasoningParser):
+ """
+ Base class for reasoning parsers that use thinking tokens.
+
+ This class provides common functionality for parsers that use start and end
+ tokens to delimit reasoning content (
+ e.g., ..., ...).
+
+ Subclasses must implement the start and end tokens via abstract
+ properties.
+ """
+
+ @property
+ @abstractmethod
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ raise NotImplementedError
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ReasoningParser "
+ "constructor during construction.")
+
+ if not self.start_token or not self.end_token:
+ raise ValueError(
+ "start_token and end_token must be defined in subclasses")
+
+ self.start_token_id = self.vocab.get(self.start_token)
+ self.end_token_id = self.vocab.get(self.end_token)
+ if self.start_token_id is None or self.end_token_id is None:
+ raise RuntimeError(
+ f"{self.__class__.__name__} reasoning parser could not locate "
+ "think start/end tokens in the tokenizer!")
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.end_token_id in input_ids
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ """
+ Extract the content after the end tokens
+ """
+ if self.end_token_id not in input_ids[:-1]:
+ return []
+ else:
+ return input_ids[input_ids.index(self.end_token_id) + 1:]
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+ """
+ Extract reasoning content from a delta message.
+ Handles streaming output where previous + delta = current.
+ Uses token IDs for faster processing.
+ """
+ # Skip single special tokens
+ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+ self.start_token_id, self.end_token_id
+ ]):
+ return None
+
+ # Check if start token is present in previous or delta.
+ # Keep compatibility with models that don't generate start tokens.
+ if self.start_token_id in previous_token_ids:
+ if self.end_token_id in delta_token_ids:
+ # start token in previous, end token in delta,
+ # extract reasoning content
+ end_index = delta_text.find(self.end_token)
+ reasoning_content = delta_text[:end_index]
+ content = delta_text[end_index + len(self.end_token):]
+ return DeltaMessage(
+ reasoning_content=reasoning_content,
+ content=content if content else None,
+ )
+ elif self.end_token_id in previous_token_ids:
+ # start token in previous, end token in previous,
+ # reasoning content continues
+ return DeltaMessage(content=delta_text)
+ else:
+ # start token in previous, no end token in previous or delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ elif self.start_token_id in delta_token_ids:
+ if self.end_token_id in delta_token_ids:
+ # start token in delta, end token in delta,
+ # extract reasoning content
+ start_index = delta_text.find(self.start_token)
+ end_index = delta_text.find(self.end_token)
+ reasoning_content = delta_text[start_index +
+ len(self.start_token):end_index]
+ content = delta_text[end_index + len(self.end_token):]
+ return DeltaMessage(
+ reasoning_content=reasoning_content,
+ content=content if content else None,
+ )
+ else:
+ # start token in delta, no end token in delta,
+ # reasoning content continues
+ return DeltaMessage(reasoning_content=delta_text)
+ else:
+ # not find thinking start token
+ return DeltaMessage(content=delta_text)
+
+ def extract_reasoning_content(
+ self, model_output: str, request: Union[ChatCompletionRequest,
+ ResponsesRequest]
+ ) -> tuple[Optional[str], Optional[str]]:
+ """
+ Extract reasoning content from the model output.
+
+ This is the base implementation that works for most models.
+ Subclasses can override this method for specific behavior.
+ """
+ # Check if the start token is present in the model output, remove it
+ # if it is present.
+ model_output_parts = model_output.partition(self.start_token)
+ model_output = model_output_parts[2] if model_output_parts[
+ 1] else model_output_parts[0]
+
+ # For models that may not generate start token,
+ # assume the reasoning content is always at the start.
+ if self.end_token not in model_output:
+ return model_output, None
+ else:
+ reasoning_content, _, content = model_output.partition(
+ self.end_token)
+ # If generation stops right after end-of-think, return null content
+ final_content = content or None
+ return reasoning_content, final_content
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f1..76d2959e1c9a 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -2,20 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
-from typing import Optional, Union
+from typing import Union
-from transformers import PreTrainedTokenizerBase
-
-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
- DeltaMessage)
-from vllm.logger import init_logger
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
-
-logger = init_logger(__name__)
+from vllm.entrypoints.openai.protocol import DeltaMessage
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
@ReasoningParserManager.register_module("deepseek_r1")
-class DeepSeekR1ReasoningParser(ReasoningParser):
+class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
@@ -23,38 +18,15 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
- start_token_id: int
- end_token_id: int
+ @property
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ return ""
- start_token: str = ""
- end_token: str = ""
-
- def __init__(self, tokenizer: PreTrainedTokenizerBase):
- super().__init__(tokenizer)
-
- if not self.model_tokenizer:
- raise ValueError(
- "The model tokenizer must be passed to the ReasoningParser "
- "constructor during construction.")
-
- self.start_token_id = self.vocab.get(self.start_token)
- self.end_token_id = self.vocab.get(self.end_token)
- if self.start_token_id is None or self.end_token_id is None:
- raise RuntimeError(
- "DeepSeek R1 reasoning parser could not locate think start/end "
- "tokens in the tokenizer!")
-
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.end_token_id in input_ids
-
- def extract_content_ids(self, input_ids: list[int]) -> list[int]:
- """
- Extract the content after the end tokens
- """
- if self.end_token_id not in input_ids[:-1]:
- return []
- else:
- return input_ids[input_ids.index(self.end_token_id) + 1:]
+ @property
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ return ""
def extract_reasoning_content_streaming(
self,
@@ -65,63 +37,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
- """
- Extract reasoning content from a delta message.
- Handles streaming output where previous + delta = current.
- Uses token IDs for faster processing.
- For text abcxyz:
- - 'abc' goes to reasoning_content
- - 'xyz' goes to content
- """
- # Skip single special tokens
- if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
- self.start_token_id, self.end_token_id
- ]):
- return None
-
- # Check if is present in previous or delta.
- # Keep compatibility with models that don't generate tokens.
- if self.start_token_id in previous_token_ids:
+ ret = super().extract_reasoning_content_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ )
+ if (ret is not None and self.start_token_id not in previous_token_ids
+ and self.start_token_id not in delta_token_ids):
if self.end_token_id in delta_token_ids:
- # in previous, in delta,
- # extract reasoning content
- end_index = delta_text.find(self.end_token)
- reasoning_content = delta_text[:end_index]
- content = delta_text[end_index + len(self.end_token):]
- return DeltaMessage(
- reasoning_content=reasoning_content,
- content=content if content else None,
- )
- elif self.end_token_id in previous_token_ids:
- # in previous, in previous,
- # reasoning content continues
- return DeltaMessage(content=delta_text)
- else:
- # in previous, no in previous or delta,
- # reasoning content continues
- return DeltaMessage(reasoning_content=delta_text)
- elif self.start_token_id in delta_token_ids:
- if self.end_token_id in delta_token_ids:
- # in delta, in delta, extract reasoning content
- start_index = delta_text.find(self.start_token)
- end_index = delta_text.find(self.end_token)
- reasoning_content = delta_text[start_index +
- len(self.start_token):end_index]
- content = delta_text[end_index + len(self.end_token):]
- return DeltaMessage(
- reasoning_content=reasoning_content,
- content=content if content else None,
- )
- else:
- # in delta, no in delta,
- # reasoning content continues
- return DeltaMessage(reasoning_content=delta_text)
- else:
- # No in previous or delta, also need to check for .
- # Because the model may have generated without
- # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.end_token_id in delta_token_ids:
- # in delta with more tokens,
+ # end token in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
@@ -131,43 +58,10 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
- # in previous, thinking content ends
+ # end token in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
- # no in previous or delta, reasoning content continues
+ # no end token in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
- def extract_reasoning_content(
- self, model_output: str, request: ChatCompletionRequest
- ) -> tuple[Optional[str], Optional[str]]:
- """
- Extract reasoning content from the model output.
-
- For text abcxyz:
- - 'abc' goes to reasoning_content
- - 'xyz' goes to content
-
- Returns:
- tuple[Optional[str], Optional[str]]: reasoning content and content
- """
-
- # Check if the start token is present in the model output, remove it
- # if it is present.
- model_output_parts = model_output.partition(self.start_token)
- model_output = model_output_parts[2] if model_output_parts[
- 1] else model_output_parts[0]
-
- # DeepSeek R1 doesn't generate now.
- # Thus we assume the reasoning content is always at the start.
- # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.end_token not in model_output:
- return model_output, None
- else:
- reasoning_content, _, content = model_output.partition(
- self.end_token)
- # If the end token is not found, return the model output as is.
- # It should not happen since we already checked for the presence
- # of the end token.
- # If generation stops right after end-of-think, return null content
- final_content = content or None
- return reasoning_content, final_content
+ return ret
diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py
index 6c707a4079fa..5cb54e6acbb3 100644
--- a/vllm/reasoning/mistral_reasoning_parser.py
+++ b/vllm/reasoning/mistral_reasoning_parser.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from functools import cached_property
+
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import (
@@ -31,11 +33,6 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
- from mistral_common.tokens.tokenizers.base import SpecialTokens
-
- self.start_token = SpecialTokens.begin_think
- self.end_token = SpecialTokens.end_think
-
self.start_token_id = tokenizer.tokenizer.get_control_token(
self.start_token)
self.end_token_id = tokenizer.tokenizer.get_control_token(
@@ -45,3 +42,15 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
raise RuntimeError(
"Mistral reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
+
+ @cached_property
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ from mistral_common.tokens.tokenizers.base import SpecialTokens
+ return SpecialTokens.begin_think
+
+ @cached_property
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ from mistral_common.tokens.tokenizers.base import SpecialTokens
+ return SpecialTokens.end_think
diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py
index 61bafc724c17..3e3c7f32796b 100644
--- a/vllm/reasoning/qwen3_reasoning_parser.py
+++ b/vllm/reasoning/qwen3_reasoning_parser.py
@@ -1,21 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Sequence
from typing import Optional, Union
-from transformers import PreTrainedTokenizerBase
-
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
- DeltaMessage)
-from vllm.logger import init_logger
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
-
-logger = init_logger(__name__)
+ ResponsesRequest)
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
@ReasoningParserManager.register_module("qwen3")
-class Qwen3ReasoningParser(ReasoningParser):
+class Qwen3ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for the Qwen3 model.
@@ -26,100 +21,25 @@ class Qwen3ReasoningParser(ReasoningParser):
output.
"""
- def __init__(self, tokenizer: PreTrainedTokenizerBase):
- super().__init__(tokenizer)
- self.think_start_token = ""
- self.think_end_token = ""
+ @property
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ return ""
- if not self.model_tokenizer:
- raise ValueError(
- "The model tokenizer must be passed to the ReasoningParser "
- "constructor during construction.")
-
- self.think_start_token_id = self.vocab.get(self.think_start_token)
- self.think_end_token_id = self.vocab.get(self.think_end_token)
- if (self.think_start_token_id is None
- or self.think_end_token_id is None):
- raise RuntimeError(
- "Qwen3 reasoning parser could not locate think start/end "
- "tokens in the tokenizer!")
-
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.think_end_token_id in input_ids
-
- def extract_content_ids(self, input_ids: list[int]) -> list[int]:
- """
- Extract the content after the end tokens
- """
- if self.think_end_token_id not in input_ids[:-1]:
- return []
- else:
- return input_ids[input_ids.index(self.think_end_token_id) + 1:]
-
- def extract_reasoning_content_streaming(
- self,
- previous_text: str,
- current_text: str,
- delta_text: str,
- previous_token_ids: Sequence[int],
- current_token_ids: Sequence[int],
- delta_token_ids: Sequence[int],
- ) -> Union[DeltaMessage, None]:
- """
- Extract reasoning content from a delta message.
- Handles streaming output where previous + delta = current.
- Uses token IDs for faster processing.
- For text abcxyz:
- - 'abc' goes to reasoning_content
- - 'xyz' goes to content
- """
- # Skip single special tokens
- if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
- self.think_start_token_id, self.think_end_token_id
- ]):
- return None
-
- if self.think_start_token_id in previous_token_ids:
- if self.think_end_token_id in delta_token_ids:
- # in previous, in delta,
- # extract reasoning content
- end_index = delta_text.find(self.think_end_token)
- reasoning_content = delta_text[:end_index]
- content = delta_text[end_index + len(self.think_end_token):]
- return DeltaMessage(reasoning_content=reasoning_content,
- content=content if content else None)
- elif self.think_end_token_id in previous_token_ids:
- # in previous, in previous,
- # reasoning content continues
- return DeltaMessage(content=delta_text)
- else:
- # in previous, no in previous or delta,
- # reasoning content continues
- return DeltaMessage(reasoning_content=delta_text)
- elif self.think_start_token_id in delta_token_ids:
- if self.think_end_token_id in delta_token_ids:
- # in delta, in delta, extract reasoning content
- start_index = delta_text.find(self.think_start_token)
- end_index = delta_text.find(self.think_end_token)
- reasoning_content = delta_text[start_index +
- len(self.think_start_token
- ):end_index]
- content = delta_text[end_index + len(self.think_end_token):]
- return DeltaMessage(reasoning_content=reasoning_content,
- content=content if content else None)
- else:
- # in delta, no in delta,
- # reasoning content continues
- return DeltaMessage(reasoning_content=delta_text)
- else:
- # thinking is disabled, just content
- return DeltaMessage(content=delta_text)
+ @property
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ return ""
def extract_reasoning_content(
- self, model_output: str, request: ChatCompletionRequest
+ self, model_output: str, request: Union[ChatCompletionRequest,
+ ResponsesRequest]
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
+
+ Qwen3 has stricter requirements - it needs both start and end tokens
+ to be present, unlike other models that work with just the end token.
For text abcxyz:
- 'abc' goes to reasoning_content
@@ -129,23 +49,24 @@ class Qwen3ReasoningParser(ReasoningParser):
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
- # Check if the model output contains the and tokens.
- if (self.think_start_token not in model_output
- or self.think_end_token not in model_output):
+ # Check if the model output contains both and tokens.
+ if (self.start_token not in model_output
+ or self.end_token not in model_output):
return None, model_output
+
# Check if the is present in the model output, remove it
# if it is present.
- model_output_parts = model_output.partition(self.think_start_token)
+ model_output_parts = model_output.partition(self.start_token)
model_output = model_output_parts[2] if model_output_parts[
1] else model_output_parts[0]
+
# Check if the model output contains the tokens.
# If the end token is not found, return the model output as is.
- if self.think_end_token not in model_output:
+ if self.end_token not in model_output:
return None, model_output
# Extract reasoning content from the model output.
- reasoning_content, _, content = model_output.partition(
- self.think_end_token)
+ reasoning_content, _, content = model_output.partition(self.end_token)
final_content = content or None
return reasoning_content, final_content
diff --git a/vllm/reasoning/seedoss_reasoning_parser.py b/vllm/reasoning/seedoss_reasoning_parser.py
new file mode 100644
index 000000000000..5f4bbbf1557e
--- /dev/null
+++ b/vllm/reasoning/seedoss_reasoning_parser.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+@ReasoningParserManager.register_module("seed_oss")
+class SeedOSSReasoningParser(BaseThinkingReasoningParser):
+ """
+ Reasoning parser for SeedOSS model.
+
+ The SeedOSS model uses ... tokens to
+ denote reasoning content text. This parser extracts
+ the reasoning content from the model output.
+ Similar to DeepSeek R1, it supports cases
+ where the model doesn't generate the start token.
+ """
+
+ @property
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ return ""