diff --git a/tests/entrypoints/openai/reasoning_parsers/__init__.py b/tests/reasoning/__init__.py
similarity index 100%
rename from tests/entrypoints/openai/reasoning_parsers/__init__.py
rename to tests/reasoning/__init__.py
diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/reasoning/test_deepseekr1_reasoning_parser.py
similarity index 75%
rename from tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
rename to tests/reasoning/test_deepseekr1_reasoning_parser.py
index 5ce5d9280f3ef..7b6af183a86ad 100644
--- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
+++ b/tests/reasoning/test_deepseekr1_reasoning_parser.py
@@ -3,74 +3,92 @@
import pytest
from transformers import AutoTokenizer
-from tests.entrypoints.openai.reasoning_parsers.utils import (
- run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
- ReasoningParserManager)
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "deepseek_r1"
start_token = ""
end_token = ""
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+ return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
SIMPLE_REASONING = {
"output": "This is a reasoning sectionThis is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
+ "is_reasoning_end": True,
}
COMPLETE_REASONING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
+ "is_reasoning_end": True,
}
NO_CONTENT = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
+ "is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
+ "is_reasoning_end": False,
}
MULTIPLE_LINES = {
"output": "This\nThatThis is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
+ "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
"output": "This is the rest",
"reasoning_content": "",
"content": "This is the rest",
+ "is_reasoning_end": True,
}
SHORTEST_REASONING = {
"output": "This is the rest",
"reasoning_content": None,
"content": "This is the rest",
+ "is_reasoning_end": True,
}
REASONING_WITH_THINK = {
"output": "This is a reasoning sectionThis is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
+ "is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
+ "is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
"output": "This\nThatThis is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
+ "is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "This is the rest",
"reasoning_content": "",
"content": "This is the rest",
+ "is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
"output": "This is the rest",
"reasoning_content": None,
"content": "This is the rest",
+ "is_reasoning_end": True,
}
TEST_CASES = [
@@ -166,23 +184,21 @@ TEST_CASES = [
),
]
-# Global tokenizer initialization to avoid repeated loading
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-tokenizer.add_tokens([start_token, end_token])
-
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
+ deepseek_r1_qwen_tokenizer,
):
- output = tokenizer.tokenize(param_dict["output"])
+ output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
- tokenizer.convert_tokens_to_string([token]) for token in output
+ deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
+ for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
- parser_name)(tokenizer)
+ parser_name)(deepseek_r1_qwen_tokenizer)
reasoning, content = run_reasoning_extraction(parser,
output_tokens,
@@ -190,3 +206,17 @@ def test_reasoning(
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
+
+ # Test is_reasoning_end
+ output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
+ is_reasoning_end = parser.is_reasoning_end(output_ids)
+ assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+ # Test extract_content
+ if param_dict["content"] is not None:
+ content = parser.extract_content_ids(output_ids)
+ assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
+ deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
+ else:
+ content = parser.extract_content_ids(output)
+ assert content == []
diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py
similarity index 97%
rename from tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py
rename to tests/reasoning/test_granite_reasoning_parser.py
index 84ac6600498b2..48fb8c2f8d1b9 100644
--- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py
+++ b/tests/reasoning/test_granite_reasoning_parser.py
@@ -2,10 +2,8 @@
import pytest
from transformers import AutoTokenizer
-from tests.entrypoints.openai.reasoning_parsers.utils import (
- DeltaMessage, run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
- ReasoningParserManager)
+from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "granite"
START_REASONING = "Here is my thought process:"
diff --git a/tests/entrypoints/openai/reasoning_parsers/utils.py b/tests/reasoning/utils.py
similarity index 97%
rename from tests/entrypoints/openai/reasoning_parsers/utils.py
rename to tests/reasoning/utils.py
index 01e43130bc6e7..0f894ed800c6c 100644
--- a/tests/entrypoints/openai/reasoning_parsers/utils.py
+++ b/tests/reasoning/utils.py
@@ -4,7 +4,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
+from vllm.reasoning import ReasoningParser
class StreamingReasoningReconstructor:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d049f773caccd..a416fa8aa08e3 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
+from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
@@ -1119,7 +1120,7 @@ class EngineArgs:
parser.add_argument(
"--reasoning-parser",
type=str,
- choices=["deepseek_r1", "granite"],
+ choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 4856c3568319b..5682b3dabe2e8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2080,8 +2080,9 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
- logger.debug("Reasoning backend: %s",
- self.decoding_config.reasoning_backend)
+ if self.decoding_config.reasoning_backend is not None:
+ logger.debug("Building with reasoning backend %s",
+ self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 1e735da641df9..6c1f60fa6a3b4 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 3102db4050f5b..eda4722836bdb 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
- ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 0c26a60588c88..cecb3a8a1d4a8 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
+from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
- reasoner = get_reasoner(tokenizer, reasoning_backend)
+ reasoner = None
+ if reasoning_backend is not None:
+ reasoner_class = ReasoningParserManager.get_reasoning_parser(
+ reasoning_backend)
+ reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
- # Get the reasoner if needed, it will be None if reasoning_
- reasoner = get_reasoner(tokenizer, reasoning_backend)
+ reasoner = None
+ if reasoning_backend is not None:
+ reasoner_class = ReasoningParserManager.get_reasoning_parser(
+ reasoning_backend)
+ reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index 97f63ae11f457..564f9277a83c6 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
+from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
- reasoner: Optional[Reasoner],
+ reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
- reasoner: Optional[Reasoner],
+ reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
- reasoner: Optional[Reasoner],
+ reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 8b2a0f4cfe64b..31af4593f1123 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
+from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor:
- def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
+ def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
- self._reasoner: Optional[Reasoner] = reasoner
+ self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
- input_ids = self._reasoner.extract_content(input_ids)
+ input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids))
@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
- reasoner: Optional[Reasoner],
+ reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
- reasoner: Optional[Reasoner]):
+ reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
- reasoner: Optional[Reasoner]):
+ reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
deleted file mode 100644
index 7e61e6a9620c7..0000000000000
--- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from dataclasses import dataclass
-
-from transformers import PreTrainedTokenizer
-
-from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
-
-
-@dataclass
-class DeepSeekReasoner(Reasoner):
- """
- Reasoner for DeepSeek R series models.
- """
- start_token_id: int
- end_token_id: int
-
- start_token: str = ""
- end_token: str = ""
-
- @classmethod
- def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
- return cls(start_token_id=tokenizer.encode(
- "", add_special_tokens=False)[0],
- end_token_id=tokenizer.encode("",
- add_special_tokens=False)[0])
-
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.end_token_id in input_ids
-
- def extract_content(self, input_ids: list[int]) -> list[int]:
- """
- Extract the content after the end tokens
- """
- if self.end_token_id not in input_ids or \
- input_ids.index(self.end_token_id) + 1 == len(input_ids):
- return []
- else:
- return input_ids[input_ids.index(self.end_token_id) + 1:]
diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py
deleted file mode 100644
index df21b1db62218..0000000000000
--- a/vllm/model_executor/guided_decoding/reasoner/reasoner.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-
-from transformers import PreTrainedTokenizer
-
-
-@dataclass
-class Reasoner(ABC):
-
- @abstractmethod
- def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
- pass
-
- @abstractmethod
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- pass
-
- @abstractmethod
- def extract_content(self, input_ids: list[int]) -> list[int]:
- pass
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index bc156223953e0..47b1e7e3f9811 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
- from vllm.model_executor.guided_decoding.reasoner import Reasoner
+ from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
- reasoner: Reasoner | None,
+ reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
- reasoner: Reasoner | None = None
+ reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/reasoning/__init__.py
similarity index 100%
rename from vllm/entrypoints/openai/reasoning_parsers/__init__.py
rename to vllm/reasoning/__init__.py
diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
similarity index 82%
rename from vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
rename to vllm/reasoning/abs_reasoning_parsers.py
index c95ff191e4d2e..454167a0dc950 100644
--- a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
class ReasoningParser:
"""
- Abstract reasoning parser class that should not be used directly.
+ Abstract reasoning parser class that should not be used directly.
Provided and methods should be used in derived classes.
It is used to extract reasoning content from the model output.
@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
+ @abstractmethod
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ """
+ Check if the reasoning content ends in the input_ids.
+
+ It is used in structured engines like `xgrammar` to check if the
+ reasoning content ends in the model output.
+
+ Parameters:
+ input_ids: list[int]
+ The input_ids of the model output.
+
+ Returns:
+ bool
+ True if the reasoning content ends in the input_ids.
+ """
+
+ @abstractmethod
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ """
+ Extract content token ids from the input_ids.
+ Parameters:
+ input_ids: list[int]
+ The input_ids of the model output.
+ Returns:
+ list[int]
+ The extracted content from the input_ids.
+ """
+
+ @abstractmethod
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
- raise NotImplementedError(
- "AbstractReasoningParser.extract_reasoning_calls "
- "has not been implemented!")
-
+ @abstractmethod
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
- raise NotImplementedError(
- "AbstractReasoningParser.extract_reasoning_content_streaming "
- "has not been implemented!")
-
- # TODO: need to rebase by PR #14428
- @abstractmethod
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- """
- Check if the reasoning content ends in the input_ids.
- Parameters:
- input_ids: list[int]
- The input_ids of the model output.
- Returns:
- bool
- True if the reasoning content ends in the input_ids.
- """
-
- raise NotImplementedError(
- "AbstractReasoningParser.is_reasoning_end has"
- "not been implemented!")
-
- # TODO: need to rebase by PR #14428
- @abstractmethod
- def extract_content_ids(self, input_ids: list[int]) -> list[int]:
- """
- Extract content token ids from the input_ids.
- Parameters:
- input_ids: list[int]
- The input_ids of the model output.
- Returns:
- list[int]
- The extracted content from the input_ids.
- """
-
- raise NotImplementedError(
- "AbstractReasoningParser.extract_content_ids has"
- " not been implemented!")
class ReasoningParserManager:
@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name]
- raise KeyError(f"reasoning helper: '{name}' not found in "
- "reasoning_parsers")
+ raise KeyError(
+ f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod
- def _register_module(cls,
- module: type,
- module_name: Optional[Union[str, list[str]]] = None,
- force: bool = True) -> None:
+ def _register_module(
+ cls,
+ module: type,
+ module_name: Optional[Union[str, list[str]]] = None,
+ force: bool = True,
+ ) -> None:
if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}")
@@ -149,13 +141,14 @@ class ReasoningParserManager:
@classmethod
def register_module(
- cls,
- name: Optional[Union[str, list[str]]] = None,
- force: bool = True,
- module: Union[type, None] = None) -> Union[type, Callable]:
+ cls,
+ name: Optional[Union[str, list[str]]] = None,
+ force: bool = True,
+ module: Union[type, None] = None,
+ ) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
- decoder(with module as None) or normal function(with module as not
+ decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
@@ -183,7 +176,7 @@ class ReasoningParserManager:
@classmethod
def import_reasoning_parser(cls, plugin_path: str) -> None:
"""
- Import a user-defined reasoning parser by the path
+ Import a user-defined reasoning parser by the path
of the reasoning parser define file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
similarity index 64%
rename from vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
rename to vllm/reasoning/deepseek_r1_reasoning_parser.py
index 54e960168cf46..73be6d4d1ab13 100644
--- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
- ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
- The DeepSeek R1 model uses ... tokens to denote reasoning
+ The DeepSeek R1 model uses ... tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
"""
+ start_token_id: int
+ end_token_id: int
+
+ start_token: str = ""
+ end_token: str = ""
+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
- self.think_start_token = ""
- self.think_end_token = ""
self.reasoning_regex = re.compile(
- rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
+ rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
- self.think_start_token_id = self.vocab.get(self.think_start_token)
- self.think_end_token_id = self.vocab.get(self.think_end_token)
- if (self.think_start_token_id is None
- or self.think_end_token_id is None):
+ self.start_token_id = self.vocab.get(self.start_token)
+ self.end_token_id = self.vocab.get(self.end_token)
+ if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
- # TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.think_end_token_id in input_ids
+ return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
- if self.think_end_token_id not in input_ids[:-1]:
+ if self.end_token_id not in input_ids[:-1]:
return []
else:
- return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+ return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
- self.think_start_token_id, self.think_end_token_id
+ self.start_token_id, self.end_token_id
]):
return None
# Check if is present in previous or delta.
# Keep compatibility with models that don't generate tokens.
- if self.think_start_token_id in previous_token_ids:
- if self.think_end_token_id in delta_token_ids:
+ if self.start_token_id in previous_token_ids:
+ if self.end_token_id in delta_token_ids:
# in previous, in delta,
# extract reasoning content
- end_index = delta_text.find(self.think_end_token)
+ end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
- content = delta_text[end_index + len(self.think_end_token):]
- return DeltaMessage(reasoning_content=reasoning_content,
- content=content if content else None)
- elif self.think_end_token_id in previous_token_ids:
+ content = delta_text[end_index + len(self.end_token):]
+ return DeltaMessage(
+ reasoning_content=reasoning_content,
+ content=content if content else None,
+ )
+ elif self.end_token_id in previous_token_ids:
# in previous, in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# in previous, no in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
- elif self.think_start_token_id in delta_token_ids:
- if self.think_end_token_id in delta_token_ids:
+ elif self.start_token_id in delta_token_ids:
+ if self.end_token_id in delta_token_ids:
# in delta, in delta, extract reasoning content
- start_index = delta_text.find(self.think_start_token)
- end_index = delta_text.find(self.think_end_token)
+ start_index = delta_text.find(self.start_token)
+ end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
- len(self.think_start_token
- ):end_index]
- content = delta_text[end_index + len(self.think_end_token):]
- return DeltaMessage(reasoning_content=reasoning_content,
- content=content if content else None)
+ len(self.start_token):end_index]
+ content = delta_text[end_index + len(self.end_token):]
+ return DeltaMessage(
+ reasoning_content=reasoning_content,
+ content=content if content else None,
+ )
else:
# in delta, no in delta,
# reasoning content continues
@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No in previous or delta, also need to check for .
# Because the model may have generated without
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.think_end_token_id in delta_token_ids:
+ if self.end_token_id in delta_token_ids:
# in delta with more tokens,
# extract reasoning content and content
- end_index = delta_text.find(self.think_end_token)
+ end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
- content = delta_text[end_index + len(self.think_end_token):]
- return DeltaMessage(reasoning_content=reasoning_content,
- content=content if content else None)
- elif self.think_end_token_id in previous_token_ids:
+ content = delta_text[end_index + len(self.end_token):]
+ return DeltaMessage(
+ reasoning_content=reasoning_content,
+ content=content if content else None,
+ )
+ elif self.end_token_id in previous_token_ids:
# in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
-
# DeepSeek R1 doesn't generate now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.think_end_token not in model_output:
+ if self.end_token not in model_output:
return model_output, None
else:
# Add a start token if it's missing to keep compatibility.
- if self.think_start_token not in model_output:
- model_output = f"{self.think_start_token}{model_output}"
+ if self.start_token not in model_output:
+ model_output = f"{self.start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len(
- f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
- )
+ f"{self.start_token}{reasoning_content}{self.end_token}")
final_output = model_output[end_index:]
if len(final_output) == 0:
diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py
similarity index 99%
rename from vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py
rename to vllm/reasoning/granite_reasoning_parser.py
index 117d051a73782..249ace1f167fa 100644
--- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py
+++ b/vllm/reasoning/granite_reasoning_parser.py
@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
- ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)