[Model] Support SeedOss Reason Parser (#24263)

Signed-off-by: Yan Lu <luyan@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
0xNullPath 2025-09-24 08:15:51 +08:00 committed by GitHub
parent c8bde93367
commit be0bb568c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 887 additions and 246 deletions

View File

@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
# Create a concrete test implementation of BaseThinkingReasoningParser
class TestThinkingReasoningParser(BaseThinkingReasoningParser):
"""Test implementation of BaseThinkingReasoningParser."""
@property
def start_token(self) -> str:
return "<test:think>"
@property
def end_token(self) -> str:
return "</test:think>"
class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
"""Alternative test implementation with different tokens."""
@property
def start_token(self) -> str:
return "<alt:start>"
@property
def end_token(self) -> str:
return "<alt:end>"
# Use a test model
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def test_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# Add custom test tokens
test_tokens = ["<test:think>", "</test:think>", "<alt:start>", "<alt:end>"]
existing_tokens = set(tokenizer.get_vocab().keys())
new_tokens = [
token for token in test_tokens if token not in existing_tokens
]
if new_tokens:
tokenizer.add_tokens(new_tokens)
return tokenizer
class TestBaseThinkingReasoningParserInit:
"""
Test initialization and basic properties of
BaseThinkingReasoningParser.
"""
def test_successful_initialization(self, test_tokenizer):
"""Test successful initialization with valid tokens."""
parser = TestThinkingReasoningParser(test_tokenizer)
assert parser.start_token == "<test:think>"
assert parser.end_token == "</test:think>"
assert parser.start_token_id is not None
assert parser.end_token_id is not None
def test_initialization_with_missing_tokenizer(self):
"""Test that initialization fails without tokenizer."""
with pytest.raises(ValueError, match="model tokenizer must be passed"):
TestThinkingReasoningParser(None)
def test_initialization_with_missing_tokens(self, test_tokenizer):
"""Test that initialization fails when tokens are not in vocabulary."""
# Create a parser with tokens not in vocabulary
class MissingTokenParser(BaseThinkingReasoningParser):
@property
def start_token(self) -> str:
return "<missing:start>"
@property
def end_token(self) -> str:
return "<missing:end>"
with pytest.raises(RuntimeError,
match="could not locate think start/end tokens"):
MissingTokenParser(test_tokenizer)
def test_initialization_with_empty_tokens(self, test_tokenizer):
"""Test that initialization fails with empty token strings."""
class EmptyTokenParser(BaseThinkingReasoningParser):
@property
def start_token(self) -> str:
return ""
@property
def end_token(self) -> str:
return ""
with pytest.raises(ValueError,
match="start_token and end_token must be defined"):
EmptyTokenParser(test_tokenizer)
class TestBaseThinkingReasoningParserMethods:
"""Test the methods of BaseThinkingReasoningParser."""
def test_is_reasoning_end(self, test_tokenizer):
"""Test the is_reasoning_end method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
# Test with end token present
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
# Test without end token
assert parser.is_reasoning_end([1, 2, 3, 4]) is False
# Test with empty list
assert parser.is_reasoning_end([]) is False
def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
# Test with end token in the middle
input_ids = [1, 2, end_token_id, 4, 5]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == [4, 5]
# Test with end token at the end
input_ids = [1, 2, 3, end_token_id]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == []
# Test without end token
input_ids = [1, 2, 3, 4]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == []
# Test with end token as last element (should not extract)
input_ids = [1, 2, 3, end_token_id]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == []
class TestBaseThinkingReasoningParserExtraction:
"""Test reasoning content extraction methods."""
def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
"""Test extraction when both start and end tokens are present."""
parser = TestThinkingReasoningParser(test_tokenizer)
request = ChatCompletionRequest(messages=[], model="test-model")
model_output = ("<test:think>This is reasoning"
"</test:think>This is content")
reasoning, content = parser.extract_reasoning_content(
model_output, request)
assert reasoning == "This is reasoning"
assert content == "This is content"
def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
"""Test extraction when only end token is present."""
parser = TestThinkingReasoningParser(test_tokenizer)
request = ChatCompletionRequest(messages=[], model="test-model")
model_output = ("This is reasoning</test:think>This is content")
reasoning, content = parser.extract_reasoning_content(
model_output, request)
assert reasoning == "This is reasoning"
assert content == "This is content"
def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
"""Test extraction when no end token is present."""
parser = TestThinkingReasoningParser(test_tokenizer)
request = ChatCompletionRequest(messages=[], model="test-model")
model_output = "This is just content"
reasoning, content = parser.extract_reasoning_content(
model_output, request)
assert reasoning == "This is just content"
assert content is None
def test_extract_reasoning_content_empty_output(self, test_tokenizer):
"""Test extraction with empty output."""
parser = TestThinkingReasoningParser(test_tokenizer)
request = ChatCompletionRequest(messages=[], model="test-model")
model_output = ""
reasoning, content = parser.extract_reasoning_content(
model_output, request)
assert reasoning == ""
assert content is None
def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
"""Test extraction with only tokens and no content."""
parser = TestThinkingReasoningParser(test_tokenizer)
request = ChatCompletionRequest(messages=[], model="test-model")
model_output = ("<test:think></test:think>")
reasoning, content = parser.extract_reasoning_content(
model_output, request)
assert reasoning == ""
assert content is None
class TestBaseThinkingReasoningParserStreaming:
"""Test streaming functionality of BaseThinkingReasoningParser."""
@pytest.mark.parametrize("streaming", [True, False])
def test_simple_reasoning_extraction(self, test_tokenizer, streaming):
"""
Test basic reasoning extraction in both
streaming and non-streaming modes.
"""
parser = TestThinkingReasoningParser(test_tokenizer)
model_output = [
"<test:think>", "Some ", "reasoning ", "content", "</test:think>",
"Final ", "answer"
]
reasoning, content = run_reasoning_extraction(parser,
model_output,
streaming=streaming)
assert reasoning == "Some reasoning content"
assert content == "Final answer"
def test_streaming_with_incremental_deltas(self, test_tokenizer):
"""Test streaming processing with small incremental deltas."""
parser = TestThinkingReasoningParser(test_tokenizer)
deltas = [
"<test:think>",
"Some ",
"reasoning ",
"content",
"</test:think>",
"Final ",
"answer",
]
reasoning, content = run_reasoning_extraction(parser,
deltas,
streaming=True)
assert reasoning == "Some reasoning content"
assert content == "Final answer"
def test_streaming_with_start_token(self, test_tokenizer):
"""Test streaming with start token included."""
parser = TestThinkingReasoningParser(test_tokenizer)
deltas = [
"<test:think>",
"Some ",
"reasoning",
"</test:think>",
"Answer",
]
reasoning, content = run_reasoning_extraction(parser,
deltas,
streaming=True)
assert reasoning == "Some reasoning"
assert content == "Answer"
def test_streaming_no_end_token(self, test_tokenizer):
"""Test streaming when no end token is encountered."""
parser = TestThinkingReasoningParser(test_tokenizer)
deltas = [
"<test:think>",
"Some ",
"reasoning ",
"without ",
"end",
]
reasoning, content = run_reasoning_extraction(parser,
deltas,
streaming=True)
assert reasoning == "Some reasoning without end"
assert content is None
def test_streaming_only_end_token(self, test_tokenizer):
"""Test streaming when only end token appears."""
parser = TestThinkingReasoningParser(test_tokenizer)
deltas = [
"<test:think>",
"Reasoning ",
"content",
"</test:think>",
"Final",
]
reasoning, content = run_reasoning_extraction(parser,
deltas,
streaming=True)
assert reasoning == "Reasoning content"
assert content == "Final"
class TestBaseThinkingReasoningParserMultipleImplementations:
"""
Test that multiple implementations of
BaseThinkingReasoningParser work correctly.
"""
def test_different_token_implementations(self, test_tokenizer):
"""
Test that different implementations
with different tokens work independently.
"""
parser1 = TestThinkingReasoningParser(test_tokenizer)
parser2 = TestThinkingReasoningParserAlt(test_tokenizer)
# Test parser1
model_output1 = ("Reasoning1</test:think>Content1")
reasoning1, content1 = run_reasoning_extraction(
parser1, [model_output1])
assert reasoning1 == "Reasoning1"
assert content1 == "Content1"
# Test parser2
model_output2 = "Reasoning2<alt:end>Content2"
reasoning2, content2 = run_reasoning_extraction(
parser2, [model_output2])
assert reasoning2 == "Reasoning2"
assert content2 == "Content2"
# Verify tokens are different
assert parser1.start_token != parser2.start_token
assert parser1.end_token != parser2.end_token
assert parser1.start_token_id != parser2.start_token_id
assert parser1.end_token_id != parser2.end_token_id
class TestBaseThinkingReasoningParserEdgeCases:
"""Test edge cases and error conditions."""
def test_multiple_end_tokens(self, test_tokenizer):
"""Test behavior with multiple end tokens."""
parser = TestThinkingReasoningParser(test_tokenizer)
model_output = ("First</test:think>Middle</test:think>Last")
reasoning, content = run_reasoning_extraction(parser, [model_output])
# Should stop at first end token
assert reasoning == "First"
assert content == "Middle</test:think>Last"
def test_nested_tokens(self, test_tokenizer):
"""Test behavior with nested-like token patterns."""
parser = TestThinkingReasoningParser(test_tokenizer)
model_output = ("<test:think>Outer"
"<test:think>Inner</test:think>Content")
reasoning, content = run_reasoning_extraction(parser, [model_output])
# Should process normally, start from first start token
assert reasoning == "Outer<test:think>Inner"
assert content == "Content"
def test_malformed_tokens(self, test_tokenizer):
"""Test behavior with malformed token-like strings."""
parser = TestThinkingReasoningParser(test_tokenizer)
model_output = ("<test:thinking>Not a real token"
"</test:thinking>Content")
reasoning, content = run_reasoning_extraction(parser, [model_output])
# Should treat as regular content since tokens don't match exactly
assert reasoning == ("<test:thinking>Not a real token"
"</test:thinking>Content")
assert content is None

View File

@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "seed_oss"
start_token = "<seed:think>"
end_token = "</seed:think>"
# Use a test model that contains our custom tokens
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def seedoss_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# Add custom SeedOSS tokens if they don't exist
if start_token not in tokenizer.get_vocab():
tokenizer.add_tokens([start_token, end_token])
return tokenizer
SIMPLE_REASONING: dict[str, Any] = {
"output": "This is a reasoning section</seed:think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING: dict[str, Any] = {
"output": "This is a reasoning section</seed:think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
NO_CONTENT: dict[str, Any] = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING_STREAMING: dict[str, Any] = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
MULTIPLE_LINES: dict[str, Any] = {
"output": "This\nThat</seed:think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
WITH_START_TOKEN: dict[str, Any] = {
"output": ("<seed:think>This is a reasoning section"
"</seed:think>This is the rest"),
"reasoning_content":
"This is a reasoning section",
"content":
"This is the rest",
"is_reasoning_end":
True,
}
ONLY_END_TOKEN: dict[str, Any] = {
"output": "Some reasoning</seed:think>This is the rest",
"reasoning_content": "Some reasoning",
"content": "This is the rest",
"is_reasoning_end": True,
}
NO_TOKENS: dict[str, Any] = {
"output": "This is just content without any reasoning tokens",
"reasoning_content": "This is just content without any reasoning tokens",
"content": None,
"is_reasoning_end": False,
}
def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
"""Test that the SeedOSS reasoning parser can be created and registered."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
assert isinstance(parser, ReasoningParser)
assert parser.start_token == start_token
assert parser.end_token == end_token
@pytest.mark.parametrize("streaming", [True, False])
def test_simple_reasoning(seedoss_tokenizer, streaming):
"""Test basic reasoning extraction with both tokens."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming)
assert reasoning == SIMPLE_REASONING["reasoning_content"]
assert content == SIMPLE_REASONING["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_complete_reasoning(seedoss_tokenizer, streaming):
"""Test reasoning extraction when there's no content after reasoning."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming)
assert reasoning == COMPLETE_REASONING["reasoning_content"]
assert content == COMPLETE_REASONING["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_no_content(seedoss_tokenizer, streaming):
"""Test when there's no end token - everything is reasoning content."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, NO_CONTENT["output"])], streaming=streaming)
assert reasoning == NO_CONTENT["reasoning_content"]
assert content == NO_CONTENT["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_multiple_lines(seedoss_tokenizer, streaming):
"""Test reasoning extraction with multiline content."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming)
assert reasoning == MULTIPLE_LINES["reasoning_content"]
assert content == MULTIPLE_LINES["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_with_start_token(seedoss_tokenizer, streaming):
"""Test reasoning extraction with both start and end tokens."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming)
assert reasoning == WITH_START_TOKEN["reasoning_content"]
assert content == WITH_START_TOKEN["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_only_end_token(seedoss_tokenizer, streaming):
"""
Test reasoning extraction with only end token
(SeedOSS typical behavior).
"""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming)
assert reasoning == ONLY_END_TOKEN["reasoning_content"]
assert content == ONLY_END_TOKEN["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tokens(seedoss_tokenizer, streaming):
"""Test when there are no reasoning tokens at all."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
reasoning, content = run_reasoning_extraction(
parser, [cast(str, NO_TOKENS["output"])], streaming=streaming)
assert reasoning == NO_TOKENS["reasoning_content"]
assert content == NO_TOKENS["content"]
def test_is_reasoning_end(seedoss_tokenizer):
"""Test the is_reasoning_end method."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
# Test with end token present
end_token_id = parser.end_token_id
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
# Test without end token
assert parser.is_reasoning_end([1, 2, 3, 4]) is False
def test_extract_content_ids(seedoss_tokenizer):
"""Test the extract_content_ids method."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
end_token_id = parser.end_token_id
# Test with end token in the middle
input_ids = [1, 2, end_token_id, 4, 5]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == [4, 5]
# Test with end token at the end
input_ids = [1, 2, 3, end_token_id]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == []
# Test without end token
input_ids = [1, 2, 3, 4]
content_ids = parser.extract_content_ids(input_ids)
assert content_ids == []
def test_streaming_delta_processing(seedoss_tokenizer):
"""Test streaming processing with small deltas."""
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(seedoss_tokenizer)
# Test streaming with incremental tokens
deltas = [
"Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer"
]
reasoning, content = run_reasoning_extraction(parser,
deltas,
streaming=True)
assert reasoning == "Some reasoning content"
assert content == "Final answer"

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .basic_parsers import BaseThinkingReasoningParser
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .gptoss_reasoning_parser import GptOssReasoningParser
@ -9,10 +10,12 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
from .seedoss_reasoning_parser import SeedOSSReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
"BaseThinkingReasoningParser",
"ReasoningParserManager",
"DeepSeekR1ReasoningParser",
"GraniteReasoningParser",
@ -22,4 +25,5 @@ __all__ = [
"MistralReasoningParser",
"Step3ReasoningParser",
"GptOssReasoningParser",
"SeedOSSReasoningParser",
]

View File

@ -7,7 +7,7 @@ import os
from abc import abstractmethod
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Union
from vllm.logger import init_logger
from vllm.utils import import_from_path, is_list_of
@ -77,7 +77,7 @@ class ReasoningParser:
self,
model_output: str,
request: Union[ChatCompletionRequest, ResponsesRequest],
) -> tuple[Optional[str], Optional[str]]:
) -> tuple[str | None, str | None]:
"""
Extract reasoning content from a complete model-generated string.
@ -135,7 +135,7 @@ class ReasoningParserManager:
def _register_module(
cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
module_name: Union[str, list[str]] | None = None,
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser):
@ -155,7 +155,7 @@ class ReasoningParserManager:
@classmethod
def register_module(
cls,
name: Optional[Union[str, list[str]]] = None,
name: Union[str, list[str]] | None = None,
force: bool = True,
module: Union[type, None] = None,
) -> Union[type, Callable]:

View File

@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Sequence
from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, ResponsesRequest)
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
class BaseThinkingReasoningParser(ReasoningParser):
"""
Base class for reasoning parsers that use thinking tokens.
This class provides common functionality for parsers that use start and end
tokens to delimit reasoning content (
e.g., <think>...</think>, <seed:think>...</seed:think>).
Subclasses must implement the start and end tokens via abstract
properties.
"""
@property
@abstractmethod
def start_token(self) -> str:
"""The token that starts reasoning content."""
raise NotImplementedError
@property
@abstractmethod
def end_token(self) -> str:
"""The token that ends reasoning content."""
raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
if not self.start_token or not self.end_token:
raise ValueError(
"start_token and end_token must be defined in subclasses")
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
f"{self.__class__.__name__} reasoning parser could not locate "
"think start/end tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.start_token_id, self.end_token_id
]):
return None
# Check if start token is present in previous or delta.
# Keep compatibility with models that don't generate start tokens.
if self.start_token_id in previous_token_ids:
if self.end_token_id in delta_token_ids:
# start token in previous, end token in delta,
# extract reasoning content
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# start token in previous, end token in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
else:
# start token in previous, no end token in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.start_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# start token in delta, end token in delta,
# extract reasoning content
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
len(self.start_token):end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
else:
# start token in delta, no end token in delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# not find thinking start token
return DeltaMessage(content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: Union[ChatCompletionRequest,
ResponsesRequest]
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
This is the base implementation that works for most models.
Subclasses can override this method for specific behavior.
"""
# Check if the start token is present in the model output, remove it
# if it is present.
model_output_parts = model_output.partition(self.start_token)
model_output = model_output_parts[2] if model_output_parts[
1] else model_output_parts[0]
# For models that may not generate start token,
# assume the reasoning content is always at the start.
if self.end_token not in model_output:
return model_output, None
else:
reasoning_content, _, content = model_output.partition(
self.end_token)
# If generation stops right after end-of-think, return null content
final_content = content or None
return reasoning_content, final_content

View File

@ -2,20 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
from typing import Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
@ -23,38 +18,15 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
start_token_id: int
end_token_id: int
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<think>"
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</think>"
def extract_reasoning_content_streaming(
self,
@ -65,63 +37,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.start_token_id, self.end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.start_token_id in previous_token_ids:
ret = super().extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)
if (ret is not None and self.start_token_id not in previous_token_ids
and self.start_token_id not in delta_token_ids):
if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
else:
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.start_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
len(self.start_token):end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# end token in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
@ -131,43 +58,10 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
# end token in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
# no </think> in previous or delta, reasoning content continues
# no end token in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the start token is present in the model output, remove it
# if it is present.
model_output_parts = model_output.partition(self.start_token)
model_output = model_output_parts[2] if model_output_parts[
1] else model_output_parts[0]
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.end_token not in model_output:
return model_output, None
else:
reasoning_content, _, content = model_output.partition(
self.end_token)
# If the end token is not found, return the model output as is.
# It should not happen since we already checked for the presence
# of the end token.
# If generation stops right after end-of-think, return null content
final_content = content or None
return reasoning_content, final_content
return ret

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cached_property
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import (
@ -31,11 +33,6 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
from mistral_common.tokens.tokenizers.base import SpecialTokens
self.start_token = SpecialTokens.begin_think
self.end_token = SpecialTokens.end_think
self.start_token_id = tokenizer.tokenizer.get_control_token(
self.start_token)
self.end_token_id = tokenizer.tokenizer.get_control_token(
@ -45,3 +42,15 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
raise RuntimeError(
"Mistral reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
@cached_property
def start_token(self) -> str:
"""The token that starts reasoning content."""
from mistral_common.tokens.tokenizers.base import SpecialTokens
return SpecialTokens.begin_think
@cached_property
def end_token(self) -> str:
"""The token that ends reasoning content."""
from mistral_common.tokens.tokenizers.base import SpecialTokens
return SpecialTokens.end_think

View File

@ -1,21 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
ResponsesRequest)
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
class Qwen3ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for the Qwen3 model.
@ -26,100 +21,25 @@ class Qwen3ReasoningParser(ReasoningParser):
output.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<think>"
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
raise RuntimeError(
"Qwen3 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id
]):
return None
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
else:
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[start_index +
len(self.think_start_token
):end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# thinking is disabled, just content
return DeltaMessage(content=delta_text)
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</think>"
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self, model_output: str, request: Union[ChatCompletionRequest,
ResponsesRequest]
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
Qwen3 has stricter requirements - it needs both start and end tokens
to be present, unlike other models that work with just the end token.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
@ -129,23 +49,24 @@ class Qwen3ReasoningParser(ReasoningParser):
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the model output contains the <think> and </think> tokens.
if (self.think_start_token not in model_output
or self.think_end_token not in model_output):
# Check if the model output contains both <think> and </think> tokens.
if (self.start_token not in model_output
or self.end_token not in model_output):
return None, model_output
# Check if the <think> is present in the model output, remove it
# if it is present.
model_output_parts = model_output.partition(self.think_start_token)
model_output_parts = model_output.partition(self.start_token)
model_output = model_output_parts[2] if model_output_parts[
1] else model_output_parts[0]
# Check if the model output contains the </think> tokens.
# If the end token is not found, return the model output as is.
if self.think_end_token not in model_output:
if self.end_token not in model_output:
return None, model_output
# Extract reasoning content from the model output.
reasoning_content, _, content = model_output.partition(
self.think_end_token)
reasoning_content, _, content = model_output.partition(self.end_token)
final_content = content or None
return reasoning_content, final_content

View File

@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
@ReasoningParserManager.register_module("seed_oss")
class SeedOSSReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for SeedOSS model.
The SeedOSS model uses <seed:think>...</seed:think> tokens to
denote reasoning content text. This parser extracts
the reasoning content from the model output.
Similar to DeepSeek R1, it supports cases
where the model doesn't generate the start token.
"""
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<seed:think>"
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</seed:think>"