diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py
new file mode 100644
index 000000000000..6a939dcfc2c9
--- /dev/null
+++ b/tests/reasoning/test_base_thinking_reasoning_parser.py
@@ -0,0 +1,392 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+# Create a concrete test implementation of BaseThinkingReasoningParser
+class TestThinkingReasoningParser(BaseThinkingReasoningParser):
+    """Test implementation of BaseThinkingReasoningParser."""
+
+    @property
+    def start_token(self) -> str:
+        return "<test:think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</test:think>"
+
+
+class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
+    """Alternative test implementation with different tokens."""
+
+    @property
+    def start_token(self) -> str:
+        return "<alt:think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</alt:think>"
+
+
+# Use a test model
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def test_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+    # Add custom test tokens
+    test_tokens = [
+        "<test:think>", "</test:think>", "<alt:think>", "</alt:think>"
+    ]
+    existing_tokens = set(tokenizer.get_vocab().keys())
+    new_tokens = [
+        token for token in test_tokens if token not in existing_tokens
+    ]
+    if new_tokens:
+        tokenizer.add_tokens(new_tokens)
+    return tokenizer
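+
+
+# NOTE: add_tokens() appends any missing test tokens to the end of the
+# vocabulary, so the parsers under test can resolve real token ids even
+# though the base model never shipped with these tokens.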
+ """ + + def test_successful_initialization(self, test_tokenizer): + """Test successful initialization with valid tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + assert parser.start_token == "" + assert parser.end_token == "" + assert parser.start_token_id is not None + assert parser.end_token_id is not None + + def test_initialization_with_missing_tokenizer(self): + """Test that initialization fails without tokenizer.""" + with pytest.raises(ValueError, match="model tokenizer must be passed"): + TestThinkingReasoningParser(None) + + def test_initialization_with_missing_tokens(self, test_tokenizer): + """Test that initialization fails when tokens are not in vocabulary.""" + + # Create a parser with tokens not in vocabulary + class MissingTokenParser(BaseThinkingReasoningParser): + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + with pytest.raises(RuntimeError, + match="could not locate think start/end tokens"): + MissingTokenParser(test_tokenizer) + + def test_initialization_with_empty_tokens(self, test_tokenizer): + """Test that initialization fails with empty token strings.""" + + class EmptyTokenParser(BaseThinkingReasoningParser): + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + with pytest.raises(ValueError, + match="start_token and end_token must be defined"): + EmptyTokenParser(test_tokenizer) + + +class TestBaseThinkingReasoningParserMethods: + """Test the methods of BaseThinkingReasoningParser.""" + + def test_is_reasoning_end(self, test_tokenizer): + """Test the is_reasoning_end method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token present + assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True + + # Test without end token + assert parser.is_reasoning_end([1, 2, 3, 4]) is False + + # Test with empty list + assert parser.is_reasoning_end([]) is False + + def test_extract_content_ids(self, test_tokenizer): + """Test the extract_content_ids method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token in the middle + input_ids = [1, 2, end_token_id, 4, 5] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [4, 5] + + # Test with end token at the end + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test without end token + input_ids = [1, 2, 3, 4] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test with end token as last element (should not extract) + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + +class TestBaseThinkingReasoningParserExtraction: + """Test reasoning content extraction methods.""" + + def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer): + """Test extraction when both start and end tokens are present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = ("This is reasoning" + "This is content") + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "This is reasoning" + assert content == "This is content" + + def test_extract_reasoning_content_only_end_token(self, test_tokenizer): + """Test 
+
+
+class TestBaseThinkingReasoningParserExtraction:
+    """Test reasoning content extraction methods."""
+
+    def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
+        """Test extraction when both start and end tokens are present."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        request = ChatCompletionRequest(messages=[], model="test-model")
+
+        model_output = ("<test:think>This is reasoning</test:think>"
+                        "This is content")
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)
+
+        assert reasoning == "This is reasoning"
+        assert content == "This is content"
+
+    def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
+        """Test extraction when only the end token is present."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        request = ChatCompletionRequest(messages=[], model="test-model")
+
+        model_output = "This is reasoning</test:think>This is content"
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)
+
+        assert reasoning == "This is reasoning"
+        assert content == "This is content"
+
+    def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
+        """Test extraction when no end token is present."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        request = ChatCompletionRequest(messages=[], model="test-model")
+
+        model_output = "This is just content"
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)
+
+        assert reasoning == "This is just content"
+        assert content is None
+
+    def test_extract_reasoning_content_empty_output(self, test_tokenizer):
+        """Test extraction with empty output."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        request = ChatCompletionRequest(messages=[], model="test-model")
+
+        model_output = ""
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)
+
+        assert reasoning == ""
+        assert content is None
+
+    def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
+        """Test extraction with only tokens and no content."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        request = ChatCompletionRequest(messages=[], model="test-model")
+
+        model_output = "<test:think></test:think>"
+        reasoning, content = parser.extract_reasoning_content(
+            model_output, request)
+
+        assert reasoning == ""
+        assert content is None
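+
+
+# run_reasoning_extraction() feeds each list element to the parser as one
+# streaming delta and stitches the returned DeltaMessage chunks back
+# together, so the lists below model chunked generation.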
+ """ + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = [ + "", "Some ", "reasoning ", "content", "", + "Final ", "answer" + ] + + reasoning, content = run_reasoning_extraction(parser, + model_output, + streaming=streaming) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_incremental_deltas(self, test_tokenizer): + """Test streaming processing with small incremental deltas.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning ", + "content", + "", + "Final ", + "answer", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_start_token(self, test_tokenizer): + """Test streaming with start token included.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning", + "", + "Answer", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning" + assert content == "Answer" + + def test_streaming_no_end_token(self, test_tokenizer): + """Test streaming when no end token is encountered.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning ", + "without ", + "end", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning without end" + assert content is None + + def test_streaming_only_end_token(self, test_tokenizer): + """Test streaming when only end token appears.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Reasoning ", + "content", + "", + "Final", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Reasoning content" + assert content == "Final" + + +class TestBaseThinkingReasoningParserMultipleImplementations: + """ + Test that multiple implementations of + BaseThinkingReasoningParser work correctly. + """ + + def test_different_token_implementations(self, test_tokenizer): + """ + Test that different implementations + with different tokens work independently. 
+ """ + parser1 = TestThinkingReasoningParser(test_tokenizer) + parser2 = TestThinkingReasoningParserAlt(test_tokenizer) + + # Test parser1 + model_output1 = ("Reasoning1Content1") + reasoning1, content1 = run_reasoning_extraction( + parser1, [model_output1]) + assert reasoning1 == "Reasoning1" + assert content1 == "Content1" + + # Test parser2 + model_output2 = "Reasoning2Content2" + reasoning2, content2 = run_reasoning_extraction( + parser2, [model_output2]) + assert reasoning2 == "Reasoning2" + assert content2 == "Content2" + + # Verify tokens are different + assert parser1.start_token != parser2.start_token + assert parser1.end_token != parser2.end_token + assert parser1.start_token_id != parser2.start_token_id + assert parser1.end_token_id != parser2.end_token_id + + +class TestBaseThinkingReasoningParserEdgeCases: + """Test edge cases and error conditions.""" + + def test_multiple_end_tokens(self, test_tokenizer): + """Test behavior with multiple end tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("FirstMiddleLast") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should stop at first end token + assert reasoning == "First" + assert content == "MiddleLast" + + def test_nested_tokens(self, test_tokenizer): + """Test behavior with nested-like token patterns.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("Outer" + "InnerContent") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should process normally, start from first start token + assert reasoning == "OuterInner" + assert content == "Content" + + def test_malformed_tokens(self, test_tokenizer): + """Test behavior with malformed token-like strings.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("Not a real token" + "Content") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should treat as regular content since tokens don't match exactly + assert reasoning == ("Not a real token" + "Content") + assert content is None diff --git a/tests/reasoning/test_seedoss_reasoning_parser.py b/tests/reasoning/test_seedoss_reasoning_parser.py new file mode 100644 index 000000000000..bb5dc0f4ffe4 --- /dev/null +++ b/tests/reasoning/test_seedoss_reasoning_parser.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, cast + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "seed_oss" +start_token = "" +end_token = "" + +# Use a test model that contains our custom tokens +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def seedoss_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + # Add custom SeedOSS tokens if they don't exist + if start_token not in tokenizer.get_vocab(): + tokenizer.add_tokens([start_token, end_token]) + return tokenizer + + +SIMPLE_REASONING: dict[str, Any] = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING: dict[str, Any] = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, 
+SIMPLE_REASONING: dict[str, Any] = {
+    "output": "This is a reasoning section</seed:think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+COMPLETE_REASONING: dict[str, Any] = {
+    "output": "This is a reasoning section</seed:think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": True,
+}
+NO_CONTENT: dict[str, Any] = {
+    "output": "This is content",
+    "reasoning_content": "This is content",
+    "content": None,
+    "is_reasoning_end": False,
+}
+NO_REASONING_STREAMING: dict[str, Any] = {
+    "output": "This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": False,
+}
+MULTIPLE_LINES: dict[str, Any] = {
+    "output": "This\nThat</seed:think>This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
+}
+WITH_START_TOKEN: dict[str, Any] = {
+    "output": ("<seed:think>This is a reasoning section</seed:think>"
+               "This is the rest"),
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+ONLY_END_TOKEN: dict[str, Any] = {
+    "output": "Some reasoning</seed:think>This is the rest",
+    "reasoning_content": "Some reasoning",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+NO_TOKENS: dict[str, Any] = {
+    "output": "This is just content without any reasoning tokens",
+    "reasoning_content": "This is just content without any reasoning tokens",
+    "content": None,
+    "is_reasoning_end": False,
+}
+
+
+def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
+    """Test that the SeedOSS reasoning parser can be created and registered."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+    assert isinstance(parser, ReasoningParser)
+    assert parser.start_token == start_token
+    assert parser.end_token == end_token
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_simple_reasoning(seedoss_tokenizer, streaming):
+    """Test basic reasoning extraction with both tokens."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming)
+
+    assert reasoning == SIMPLE_REASONING["reasoning_content"]
+    assert content == SIMPLE_REASONING["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_complete_reasoning(seedoss_tokenizer, streaming):
+    """Test reasoning extraction when there's no content after reasoning."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, COMPLETE_REASONING["output"])],
+        streaming=streaming)
+
+    assert reasoning == COMPLETE_REASONING["reasoning_content"]
+    assert content == COMPLETE_REASONING["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_no_content(seedoss_tokenizer, streaming):
+    """Test when there is no end token: everything is reasoning content."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, NO_CONTENT["output"])], streaming=streaming)
+
+    assert reasoning == NO_CONTENT["reasoning_content"]
+    assert content == NO_CONTENT["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_multiple_lines(seedoss_tokenizer, streaming):
+    """Test reasoning extraction with multiline content."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming)
+
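+    # Newlines inside and after the reasoning span must survive extraction
+    # verbatim; the parser splits on the end token only.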
+    assert reasoning == MULTIPLE_LINES["reasoning_content"]
+    assert content == MULTIPLE_LINES["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_with_start_token(seedoss_tokenizer, streaming):
+    """Test reasoning extraction with both start and end tokens."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming)
+
+    assert reasoning == WITH_START_TOKEN["reasoning_content"]
+    assert content == WITH_START_TOKEN["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_only_end_token(seedoss_tokenizer, streaming):
+    """
+    Test reasoning extraction with only the end token
+    (typical SeedOSS behavior).
+    """
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming)
+
+    assert reasoning == ONLY_END_TOKEN["reasoning_content"]
+    assert content == ONLY_END_TOKEN["content"]
+
+
+@pytest.mark.parametrize("streaming", [True, False])
+def test_no_tokens(seedoss_tokenizer, streaming):
+    """Test when there are no reasoning tokens at all."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(
+        parser, [cast(str, NO_TOKENS["output"])], streaming=streaming)
+
+    assert reasoning == NO_TOKENS["reasoning_content"]
+    assert content == NO_TOKENS["content"]
+
+
+def test_is_reasoning_end(seedoss_tokenizer):
+    """Test the is_reasoning_end method."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    # Test with end token present
+    end_token_id = parser.end_token_id
+    assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
+
+    # Test without end token
+    assert parser.is_reasoning_end([1, 2, 3, 4]) is False
+
+
+def test_extract_content_ids(seedoss_tokenizer):
+    """Test the extract_content_ids method."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    end_token_id = parser.end_token_id
+
+    # Test with end token in the middle
+    input_ids = [1, 2, end_token_id, 4, 5]
+    content_ids = parser.extract_content_ids(input_ids)
+    assert content_ids == [4, 5]
+
+    # Test with end token at the end
+    input_ids = [1, 2, 3, end_token_id]
+    content_ids = parser.extract_content_ids(input_ids)
+    assert content_ids == []
+
+    # Test without end token
+    input_ids = [1, 2, 3, 4]
+    content_ids = parser.extract_content_ids(input_ids)
+    assert content_ids == []
+
+
+def test_streaming_delta_processing(seedoss_tokenizer):
+    """Test streaming processing with small deltas."""
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(seedoss_tokenizer)
+
+    # Test streaming with incremental tokens
+    deltas = [
+        "Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer"
+    ]
+
+    reasoning, content = run_reasoning_extraction(parser,
+                                                  deltas,
+                                                  streaming=True)
+
+    assert reasoning == "Some reasoning content"
+    assert content == "Final answer"
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index b987adeb6428..3c8a9c6ae0d3 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
+from .basic_parsers import BaseThinkingReasoningParser
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
 from .gptoss_reasoning_parser import GptOssReasoningParser
@@ -9,10 +10,12 @@ from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
 from .mistral_reasoning_parser import MistralReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser
+from .seedoss_reasoning_parser import SeedOSSReasoningParser
 from .step3_reasoning_parser import Step3ReasoningParser
 
 __all__ = [
     "ReasoningParser",
+    "BaseThinkingReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
@@ -22,4 +25,5 @@ __all__ = [
     "MistralReasoningParser",
     "Step3ReasoningParser",
     "GptOssReasoningParser",
+    "SeedOSSReasoningParser",
 ]
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index df9e84163f16..39b08ec11107 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -7,7 +7,7 @@ import os
 from abc import abstractmethod
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Union
 
 from vllm.logger import init_logger
 from vllm.utils import import_from_path, is_list_of
@@ -77,7 +77,7 @@ class ReasoningParser:
         self,
         model_output: str,
         request: Union[ChatCompletionRequest, ResponsesRequest],
-    ) -> tuple[Optional[str], Optional[str]]:
+    ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from a complete model-generated string.
 
@@ -135,7 +135,7 @@ class ReasoningParserManager:
     def _register_module(
         cls,
         module: type,
-        module_name: Optional[Union[str, list[str]]] = None,
+        module_name: Union[str, list[str]] | None = None,
         force: bool = True,
     ) -> None:
         if not issubclass(module, ReasoningParser):
@@ -155,7 +155,7 @@ class ReasoningParserManager:
     @classmethod
     def register_module(
         cls,
-        name: Optional[Union[str, list[str]]] = None,
+        name: Union[str, list[str]] | None = None,
         force: bool = True,
         module: Union[type, None] = None,
     ) -> Union[type, Callable]:
diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py
new file mode 100644
index 000000000000..03cb882c2693
--- /dev/null
+++ b/vllm/reasoning/basic_parsers.py
@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import abstractmethod
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage, ResponsesRequest)
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+class BaseThinkingReasoningParser(ReasoningParser):
+    """
+    Base class for reasoning parsers that use thinking tokens.
+
+    This class provides common functionality for parsers that use start and
+    end tokens to delimit reasoning content (e.g. <think>...</think> or
+    <seed:think>...</seed:think>).
+
+    Subclasses must implement the start and end tokens via abstract
+    properties.
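+
+    A minimal subclass sketch (the token strings here are illustrative,
+    not tied to any real model)::
+
+        class MyThinkingParser(BaseThinkingReasoningParser):
+
+            @property
+            def start_token(self) -> str:
+                return "<my:think>"
+
+            @property
+            def end_token(self) -> str:
+                return "</my:think>"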
+ """ + + @property + @abstractmethod + def start_token(self) -> str: + """The token that starts reasoning content.""" + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + """The token that ends reasoning content.""" + raise NotImplementedError + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + if not self.start_token or not self.end_token: + raise ValueError( + "start_token and end_token must be defined in subclasses") + + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__} reasoning parser could not locate " + "think start/end tokens in the tokenizer!") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content after the end tokens + """ + if self.end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.end_token_id) + 1:] + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. + """ + # Skip single special tokens + if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ + self.start_token_id, self.end_token_id + ]): + return None + + # Check if start token is present in previous or delta. + # Keep compatibility with models that don't generate start tokens. 
+        if self.start_token_id in previous_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # start token in previous, end token in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
+                # start token in previous, end token in previous,
+                # regular content continues after the reasoning section
+                return DeltaMessage(content=delta_text)
+            else:
+                # start token in previous, no end token in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.start_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
+                # start token in delta, end token in delta,
+                # extract reasoning content
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.start_token):end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            else:
+                # start token in delta, no end token in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # no start token found; treat the delta as regular content
+            return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+        self, model_output: str, request: Union[ChatCompletionRequest,
+                                                ResponsesRequest]
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract reasoning content from the model output.
+
+        This is the base implementation that works for most models.
+        Subclasses can override this method for model-specific behavior.
+        """
+        # Check if the start token is present in the model output, and
+        # remove it if it is.
+        model_output_parts = model_output.partition(self.start_token)
+        model_output = model_output_parts[2] if model_output_parts[
+            1] else model_output_parts[0]
+
+        # For models that may not generate the start token,
+        # assume the reasoning content is always at the start.
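+        # e.g. with tokens <t> and </t>, "<t>abc</t>xyz" and "abc</t>xyz"
+        # both yield ("abc", "xyz"), while "abc" yields ("abc", None).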
+        if self.end_token not in model_output:
+            return model_output, None
+        else:
+            reasoning_content, _, content = model_output.partition(
+                self.end_token)
+            # If generation stops right after end-of-think, return null content
+            final_content = content or None
+            return reasoning_content, final_content
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f1..76d2959e1c9a 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -2,20 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Sequence
-from typing import Optional, Union
+from typing import Union
 
-from transformers import PreTrainedTokenizerBase
-
-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
-from vllm.logger import init_logger
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
-
-logger = init_logger(__name__)
+from vllm.entrypoints.openai.protocol import DeltaMessage
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
 
 
 @ReasoningParserManager.register_module("deepseek_r1")
-class DeepSeekR1ReasoningParser(ReasoningParser):
+class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
     """
     Reasoning parser for DeepSeek R1 model.
 
@@ -23,38 +18,15 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
     text. This parser extracts the reasoning content from the model output.
     """
 
-    start_token_id: int
-    end_token_id: int
+    @property
+    def start_token(self) -> str:
+        """The token that starts reasoning content."""
+        return "<think>"
 
-    start_token: str = "<think>"
-    end_token: str = "</think>"
-
-    def __init__(self, tokenizer: PreTrainedTokenizerBase):
-        super().__init__(tokenizer)
-
-        if not self.model_tokenizer:
-            raise ValueError(
-                "The model tokenizer must be passed to the ReasoningParser "
-                "constructor during construction.")
-
-        self.start_token_id = self.vocab.get(self.start_token)
-        self.end_token_id = self.vocab.get(self.end_token)
-        if self.start_token_id is None or self.end_token_id is None:
-            raise RuntimeError(
-                "DeepSeek R1 reasoning parser could not locate think start/end "
-                "tokens in the tokenizer!")
-
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.end_token_id in input_ids
-
-    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
-        """
-        Extract the content after the end tokens
-        """
-        if self.end_token_id not in input_ids[:-1]:
-            return []
-        else:
-            return input_ids[input_ids.index(self.end_token_id) + 1:]
+    @property
+    def end_token(self) -> str:
+        """The token that ends reasoning content."""
+        return "</think>"
 
     def extract_reasoning_content_streaming(
         self,
@@ -65,63 +37,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
     ) -> Union[DeltaMessage, None]:
-        """
-        Extract reasoning content from a delta message.
-        Handles streaming output where previous + delta = current.
-        Uses token IDs for faster processing.
-        For text <think>abc</think>xyz:
-        - 'abc' goes to reasoning_content
-        - 'xyz' goes to content
-        """
-        # Skip single special tokens
-        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.start_token_id, self.end_token_id
-        ]):
-            return None
-
-        # Check if <think> is present in previous or delta.
-        # Keep compatibility with models that don't generate <think> tokens.
-        if self.start_token_id in previous_token_ids:
+        ret = super().extract_reasoning_content_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+        )
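+        # DeepSeek R1 may emit </think> without a preceding <think>, so when
+        # no start token has been seen, rework the base result and route the
+        # text before the end token to reasoning_content instead.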
+        if (ret is not None and self.start_token_id not in previous_token_ids
+                and self.start_token_id not in delta_token_ids):
             if self.end_token_id in delta_token_ids:
-                # <think> in previous, </think> in delta,
-                # extract reasoning content
-                end_index = delta_text.find(self.end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.end_token):]
-                return DeltaMessage(
-                    reasoning_content=reasoning_content,
-                    content=content if content else None,
-                )
-            elif self.end_token_id in previous_token_ids:
-                # <think> in previous, </think> in previous,
-                # reasoning content continues
-                return DeltaMessage(content=delta_text)
-            else:
-                # <think> in previous, no </think> in previous or delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        elif self.start_token_id in delta_token_ids:
-            if self.end_token_id in delta_token_ids:
-                # <think> in delta, </think> in delta, extract reasoning content
-                start_index = delta_text.find(self.start_token)
-                end_index = delta_text.find(self.end_token)
-                reasoning_content = delta_text[start_index +
-                                               len(self.start_token):end_index]
-                content = delta_text[end_index + len(self.end_token):]
-                return DeltaMessage(
-                    reasoning_content=reasoning_content,
-                    content=content if content else None,
-                )
-            else:
-                # <think> in delta, no </think> in delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        else:
-            # No <think> in previous or delta, also need to check for </think>.
-            # Because the model may have generated </think> without <think>
-            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.end_token_id in delta_token_ids:
-                # </think> in delta with more tokens,
+                # end token in delta with more tokens,
                 # extract reasoning content and content
                 end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
@@ -131,43 +58,10 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
                     content=content if content else None,
                 )
             elif self.end_token_id in previous_token_ids:
-                # </think> in previous, thinking content ends
+                # end token in previous, thinking content ends
                 return DeltaMessage(content=delta_text)
             else:
-                # no </think> in previous or delta, reasoning content continues
+                # no end token in previous or delta, reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
 
-    def extract_reasoning_content(
-        self, model_output: str, request: ChatCompletionRequest
-    ) -> tuple[Optional[str], Optional[str]]:
-        """
-        Extract reasoning content from the model output.
-
-        For text <think>abc</think>xyz:
-        - 'abc' goes to reasoning_content
-        - 'xyz' goes to content
-
-        Returns:
-            tuple[Optional[str], Optional[str]]: reasoning content and content
-        """
-
-        # Check if the start token is present in the model output, remove it
-        # if it is present.
-        model_output_parts = model_output.partition(self.start_token)
-        model_output = model_output_parts[2] if model_output_parts[
-            1] else model_output_parts[0]
-
-        # DeepSeek R1 doesn't generate <think> now.
-        # Thus we assume the reasoning content is always at the start.
-        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.end_token not in model_output:
-            return model_output, None
-        else:
-            reasoning_content, _, content = model_output.partition(
-                self.end_token)
-            # If the end token is not found, return the model output as is.
-            # It should not happen since we already checked for the presence
-            # of the end token.
-            # If generation stops right after end-of-think, return null content
-            final_content = content or None
-            return reasoning_content, final_content
+        return ret
diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py
index 6c707a4079fa..5cb54e6acbb3 100644
--- a/vllm/reasoning/mistral_reasoning_parser.py
+++ b/vllm/reasoning/mistral_reasoning_parser.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from functools import cached_property
+
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.reasoning.deepseek_r1_reasoning_parser import (
@@ -31,11 +33,6 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
                 "The model tokenizer must be passed to the ReasoningParser "
                 "constructor during construction.")
 
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
-
-        self.start_token = SpecialTokens.begin_think
-        self.end_token = SpecialTokens.end_think
-
         self.start_token_id = tokenizer.tokenizer.get_control_token(
             self.start_token)
         self.end_token_id = tokenizer.tokenizer.get_control_token(
@@ -45,3 +42,15 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
             raise RuntimeError(
                 "Mistral reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")
+
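+    # NOTE: mistral_common is imported inside the properties so that
+    # importing this module does not require mistral_common to be installed.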
""" - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) - self.think_start_token = "" - self.think_end_token = "" + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "" - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") - - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): - raise RuntimeError( - "Qwen3 reasoning parser could not locate think start/end " - "tokens in the tokenizer!") - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids - - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.think_end_token_id not in input_ids[:-1]: - return [] - else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - """ - Extract reasoning content from a delta message. - Handles streaming output where previous + delta = current. - Uses token IDs for faster processing. - For text abcxyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - """ - # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id - ]): - return None - - if self.think_start_token_id in previous_token_ids: - if self.think_end_token_id in delta_token_ids: - # in previous, in delta, - # extract reasoning content - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - elif self.think_end_token_id in previous_token_ids: - # in previous, in previous, - # reasoning content continues - return DeltaMessage(content=delta_text) - else: - # in previous, no in previous or delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - elif self.think_start_token_id in delta_token_ids: - if self.think_end_token_id in delta_token_ids: - # in delta, in delta, extract reasoning content - start_index = delta_text.find(self.think_start_token) - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - else: - # in delta, no in delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - else: - # thinking is disabled, just content - return DeltaMessage(content=delta_text) + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "" def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: Union[ChatCompletionRequest, + ResponsesRequest] ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content 
-        # Check if the model output contains the <think> and </think> tokens.
-        if (self.think_start_token not in model_output
-                or self.think_end_token not in model_output):
+        # Check if the model output contains both <think> and </think> tokens.
+        if (self.start_token not in model_output
+                or self.end_token not in model_output):
             return None, model_output
+
         # Check if the <think> is present in the model output, remove it
         # if it is present.
-        model_output_parts = model_output.partition(self.think_start_token)
+        model_output_parts = model_output.partition(self.start_token)
         model_output = model_output_parts[2] if model_output_parts[
             1] else model_output_parts[0]
+
         # Check if the model output contains the </think> tokens.
         # If the end token is not found, return the model output as is.
-        if self.think_end_token not in model_output:
+        if self.end_token not in model_output:
             return None, model_output
 
         # Extract reasoning content from the model output.
-        reasoning_content, _, content = model_output.partition(
-            self.think_end_token)
+        reasoning_content, _, content = model_output.partition(self.end_token)
 
         final_content = content or None
         return reasoning_content, final_content
diff --git a/vllm/reasoning/seedoss_reasoning_parser.py b/vllm/reasoning/seedoss_reasoning_parser.py
new file mode 100644
index 000000000000..5f4bbbf1557e
--- /dev/null
+++ b/vllm/reasoning/seedoss_reasoning_parser.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+from vllm.reasoning.deepseek_r1_reasoning_parser import (
+    DeepSeekR1ReasoningParser)
+
+
+@ReasoningParserManager.register_module("seed_oss")
+class SeedOSSReasoningParser(DeepSeekR1ReasoningParser):
+    """
+    Reasoning parser for SeedOSS model.
+
+    The SeedOSS model uses <seed:think>...</seed:think> tokens to delimit
+    reasoning text. This parser extracts the reasoning content from the
+    model output. Like DeepSeek R1 (whose streaming handling it inherits),
+    it supports outputs where the model doesn't generate the start token.
+    """
+
+    @property
+    def start_token(self) -> str:
+        """The token that starts reasoning content."""
+        return "<seed:think>"
+
+    @property
+    def end_token(self) -> str:
+        """The token that ends reasoning content."""
+        return "</seed:think>"
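+
+
+# Usage note: select this parser at serve time via the
+# `--reasoning-parser seed_oss` CLI flag.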