From 32b14baf8a1f7195ca09484de3008063569b43c5 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Fri, 28 Mar 2025 15:23:30 +0800 Subject: [PATCH] [Refactor][Frontend] Keep all logic about reasoning into one class (#14428) Signed-off-by: Ce Gao --- .../__init__.py | 0 .../test_deepseekr1_reasoning_parser.py | 52 +++++++-- .../test_granite_reasoning_parser.py | 6 +- .../reasoning_parsers => reasoning}/utils.py | 2 +- vllm/engine/arg_utils.py | 3 +- vllm/engine/llm_engine.py | 5 +- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 3 +- .../guided_decoding/__init__.py | 15 ++- .../guided_decoding/outlines_decoding.py | 8 +- .../outlines_logits_processors.py | 14 +-- .../reasoner/deepseek_reasoner.py | 38 ------- .../guided_decoding/reasoner/reasoner.py | 23 ---- .../guided_decoding/xgrammar_decoding.py | 6 +- .../__init__.py | 0 .../abs_reasoning_parsers.py | 101 ++++++++---------- .../deepseek_r1_reasoning_parser.py | 90 ++++++++-------- .../granite_reasoning_parser.py | 3 +- 18 files changed, 171 insertions(+), 200 deletions(-) rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/__init__.py (100%) rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/test_deepseekr1_reasoning_parser.py (75%) rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/test_granite_reasoning_parser.py (97%) rename tests/{entrypoints/openai/reasoning_parsers => reasoning}/utils.py (97%) delete mode 100644 vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py delete mode 100644 vllm/model_executor/guided_decoding/reasoner/reasoner.py rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/__init__.py (100%) rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/abs_reasoning_parsers.py (82%) rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/deepseek_r1_reasoning_parser.py (64%) rename vllm/{entrypoints/openai/reasoning_parsers => reasoning}/granite_reasoning_parser.py (99%) diff --git a/tests/entrypoints/openai/reasoning_parsers/__init__.py b/tests/reasoning/__init__.py similarity index 100% rename from tests/entrypoints/openai/reasoning_parsers/__init__.py rename to tests/reasoning/__init__.py diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/reasoning/test_deepseekr1_reasoning_parser.py similarity index 75% rename from tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py rename to tests/reasoning/test_deepseekr1_reasoning_parser.py index 5ce5d9280f3ef..7b6af183a86ad 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py +++ b/tests/reasoning/test_deepseekr1_reasoning_parser.py @@ -3,74 +3,92 @@ import pytest from transformers import AutoTokenizer -from tests.entrypoints.openai.reasoning_parsers.utils import ( - run_reasoning_extraction) -from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, - ReasoningParserManager) +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager parser_name = "deepseek_r1" start_token = "" end_token = "" +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def deepseek_r1_qwen_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + SIMPLE_REASONING = { "output": "This is a reasoning sectionThis is the rest", "reasoning_content": "This is a reasoning section", "content": "This is the rest", + "is_reasoning_end": True, } COMPLETE_REASONING = { "output": "This is a reasoning section", "reasoning_content": "This is a reasoning section", "content": None, + "is_reasoning_end": True, } NO_CONTENT = { "output": "This is content", "reasoning_content": "This is content", "content": None, + "is_reasoning_end": False, } NO_REASONING_STREAMING = { "output": "This is a reasoning section", "reasoning_content": "This is a reasoning section", "content": None, + "is_reasoning_end": False, } MULTIPLE_LINES = { "output": "This\nThatThis is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", + "is_reasoning_end": True, } SHORTEST_REASONING_NO_STREAMING = { "output": "This is the rest", "reasoning_content": "", "content": "This is the rest", + "is_reasoning_end": True, } SHORTEST_REASONING = { "output": "This is the rest", "reasoning_content": None, "content": "This is the rest", + "is_reasoning_end": True, } REASONING_WITH_THINK = { "output": "This is a reasoning sectionThis is the rest", "reasoning_content": "This is a reasoning section", "content": "This is the rest", + "is_reasoning_end": True, } COMPLETE_REASONING_WITH_THINK = { "output": "This is a reasoning section", "reasoning_content": "This is a reasoning section", "content": None, + "is_reasoning_end": True, } MULTIPLE_LINES_WITH_THINK = { "output": "This\nThatThis is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", + "is_reasoning_end": True, } SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { "output": "This is the rest", "reasoning_content": "", "content": "This is the rest", + "is_reasoning_end": True, } SHORTEST_REASONING_WITH_THINK = { "output": "This is the rest", "reasoning_content": None, "content": "This is the rest", + "is_reasoning_end": True, } TEST_CASES = [ @@ -166,23 +184,21 @@ TEST_CASES = [ ), ] -# Global tokenizer initialization to avoid repeated loading -tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") -tokenizer.add_tokens([start_token, end_token]) - @pytest.mark.parametrize("streaming, param_dict", TEST_CASES) def test_reasoning( streaming: bool, param_dict: dict, + deepseek_r1_qwen_tokenizer, ): - output = tokenizer.tokenize(param_dict["output"]) + output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"]) # decode everything to tokens output_tokens: list[str] = [ - tokenizer.convert_tokens_to_string([token]) for token in output + deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) + for token in output ] parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser_name)(deepseek_r1_qwen_tokenizer) reasoning, content = run_reasoning_extraction(parser, output_tokens, @@ -190,3 +206,17 @@ def test_reasoning( assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] + + # Test is_reasoning_end + output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output) + is_reasoning_end = parser.is_reasoning_end(output_ids) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + # Test extract_content + if param_dict["content"] is not None: + content = parser.extract_content_ids(output_ids) + assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids( + deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"])) + else: + content = parser.extract_content_ids(output) + assert content == [] diff --git a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py similarity index 97% rename from tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py rename to tests/reasoning/test_granite_reasoning_parser.py index 84ac6600498b2..48fb8c2f8d1b9 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_granite_reasoning_parser.py +++ b/tests/reasoning/test_granite_reasoning_parser.py @@ -2,10 +2,8 @@ import pytest from transformers import AutoTokenizer -from tests.entrypoints.openai.reasoning_parsers.utils import ( - DeltaMessage, run_reasoning_extraction) -from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, - ReasoningParserManager) +from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager parser_name = "granite" START_REASONING = "Here is my thought process:" diff --git a/tests/entrypoints/openai/reasoning_parsers/utils.py b/tests/reasoning/utils.py similarity index 97% rename from tests/entrypoints/openai/reasoning_parsers/utils.py rename to tests/reasoning/utils.py index 01e43130bc6e7..0f894ed800c6c 100644 --- a/tests/entrypoints/openai/reasoning_parsers/utils.py +++ b/tests/reasoning/utils.py @@ -4,7 +4,7 @@ from typing import Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser +from vllm.reasoning import ReasoningParser class StreamingReasoningReconstructor: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d049f773caccd..a416fa8aa08e3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.plugins import load_general_plugins +from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext @@ -1119,7 +1120,7 @@ class EngineArgs: parser.add_argument( "--reasoning-parser", type=str, - choices=["deepseek_r1", "granite"], + choices=list(ReasoningParserManager.reasoning_parsers), default=None, help= "Select the reasoning parser depending on the model that you're " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4856c3568319b..5682b3dabe2e8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2080,8 +2080,9 @@ class LLMEngine: guided_decoding.backend = guided_decoding.backend or \ self.decoding_config.guided_decoding_backend - logger.debug("Reasoning backend: %s", - self.decoding_config.reasoning_backend) + if self.decoding_config.reasoning_backend is not None: + logger.debug("Building with reasoning backend %s", + self.decoding_config.reasoning_backend) processor = get_local_guided_decoding_logits_processor( guided_params=guided_decoding, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1e735da641df9..6c1f60fa6a3b4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, TranscriptionRequest, TranscriptionResponse, UnloadLoRAAdapterRequest) -from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import ( from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import MistralTokenizer diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3102db4050f5b..eda4722836bdb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, - ReasoningParserManager) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, clamp_prompt_logprobs) from vllm.entrypoints.openai.serving_models import OpenAIServingModels @@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 0c26a60588c88..cecb3a8a1d4a8 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -5,10 +5,10 @@ from __future__ import annotations from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import get_reasoner from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) +from vllm.reasoning import ReasoningParserManager if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor( model_config: ModelConfig, reasoning_backend: str | None = None) -> LogitsProcessor | None: - reasoner = get_reasoner(tokenizer, reasoning_backend) + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) guided_params = maybe_backend_fallback(guided_params) @@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor( reasoning_backend: str | None = None) -> LogitsProcessor | None: guided_params = maybe_backend_fallback(guided_params) - # Get the reasoner if needed, it will be None if reasoning_ - reasoner = get_reasoner(tokenizer, reasoning_backend) + reasoner = None + if reasoning_backend is not None: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend_name == 'outlines': diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 97f63ae11f457..564f9277a83c6 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) -from vllm.model_executor.guided_decoding.reasoner import Reasoner +from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams @@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16 async def get_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -141,7 +141,7 @@ def _get_logits_processor( tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, whitespace_pattern: Union[str, None], - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 8b2a0f4cfe64b..31af4593f1123 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase import vllm.envs as envs from vllm.logger import init_logger -from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.platforms import current_platform +from vllm.reasoning import ReasoningParser logger = init_logger(__name__) @@ -49,9 +49,9 @@ else: class BaseLogitsProcessor: - def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): + def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): self._guide: Guide = guide - self._reasoner: Optional[Reasoner] = reasoner + self._reasoner: Optional[ReasoningParser] = reasoner # CFGState is used for the FSM state for CFGGuide self._fsm_state: DefaultDict[int, Union[int, CFGState]] = defaultdict(int) @@ -69,7 +69,7 @@ class BaseLogitsProcessor: # Remove the reasoning tokens from the input_ids # We need this because our implementation relies on the # hash of the input_ids to store the FSM state. - input_ids = self._reasoner.extract_content(input_ids) + input_ids = self._reasoner.extract_content_ids(input_ids) seq_id = hash(tuple(input_ids)) @@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): self, regex_string: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner], + reasoner: Optional[ReasoningParser], ): """Compile the FSM that drives the regex-structured generation. @@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], - reasoner: Optional[Reasoner]): + reasoner: Optional[ReasoningParser]): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor): return CFGGuide(cfg, tokenizer) def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[Reasoner]): + reasoner: Optional[ReasoningParser]): """Compile the FSM that drives the context free grammar generation. Parameters diff --git a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py b/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py deleted file mode 100644 index 7e61e6a9620c7..0000000000000 --- a/vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass - -from transformers import PreTrainedTokenizer - -from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner - - -@dataclass -class DeepSeekReasoner(Reasoner): - """ - Reasoner for DeepSeek R series models. - """ - start_token_id: int - end_token_id: int - - start_token: str = "" - end_token: str = "" - - @classmethod - def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: - return cls(start_token_id=tokenizer.encode( - "", add_special_tokens=False)[0], - end_token_id=tokenizer.encode("", - add_special_tokens=False)[0]) - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids - - def extract_content(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.end_token_id not in input_ids or \ - input_ids.index(self.end_token_id) + 1 == len(input_ids): - return [] - else: - return input_ids[input_ids.index(self.end_token_id) + 1:] diff --git a/vllm/model_executor/guided_decoding/reasoner/reasoner.py b/vllm/model_executor/guided_decoding/reasoner/reasoner.py deleted file mode 100644 index df21b1db62218..0000000000000 --- a/vllm/model_executor/guided_decoding/reasoner/reasoner.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from transformers import PreTrainedTokenizer - - -@dataclass -class Reasoner(ABC): - - @abstractmethod - def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: - pass - - @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: - pass - - @abstractmethod - def extract_content(self, input_ids: list[int]) -> list[int]: - pass diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index bc156223953e0..47b1e7e3f9811 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer from vllm.config import ModelConfig - from vllm.model_executor.guided_decoding.reasoner import Reasoner + from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams logger = init_logger(__name__) @@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, model_config: ModelConfig, - reasoner: Reasoner | None, + reasoner: ReasoningParser | None, max_threads: int = 8): config = GrammarConfig.from_guided_params(guided_params=guided_params, model_config=model_config, @@ -280,7 +280,7 @@ class GrammarConfig: class XGrammarLogitsProcessor: """Wrapper class to support pickle protocol""" config: GrammarConfig - reasoner: Reasoner | None = None + reasoner: ReasoningParser | None = None ctx: xgr.CompiledGrammar | None = None tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/reasoning/__init__.py similarity index 100% rename from vllm/entrypoints/openai/reasoning_parsers/__init__.py rename to vllm/reasoning/__init__.py diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py similarity index 82% rename from vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py rename to vllm/reasoning/abs_reasoning_parsers.py index c95ff191e4d2e..454167a0dc950 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) class ReasoningParser: """ - Abstract reasoning parser class that should not be used directly. + Abstract reasoning parser class that should not be used directly. Provided and methods should be used in derived classes. It is used to extract reasoning content from the model output. @@ -32,6 +32,36 @@ class ReasoningParser: # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() + @abstractmethod + def is_reasoning_end(self, input_ids: list[int]) -> bool: + """ + Check if the reasoning content ends in the input_ids. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output. + + Parameters: + input_ids: list[int] + The input_ids of the model output. + + Returns: + bool + True if the reasoning content ends in the input_ids. + """ + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract content token ids from the input_ids. + Parameters: + input_ids: list[int] + The input_ids of the model output. + Returns: + list[int] + The extracted content from the input_ids. + """ + + @abstractmethod def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: @@ -53,10 +83,7 @@ class ReasoningParser: A tuple containing the reasoning content and the content. """ - raise NotImplementedError( - "AbstractReasoningParser.extract_reasoning_calls " - "has not been implemented!") - + @abstractmethod def extract_reasoning_content_streaming( self, previous_text: str, @@ -73,43 +100,6 @@ class ReasoningParser: the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ - raise NotImplementedError( - "AbstractReasoningParser.extract_reasoning_content_streaming " - "has not been implemented!") - - # TODO: need to rebase by PR #14428 - @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: - """ - Check if the reasoning content ends in the input_ids. - Parameters: - input_ids: list[int] - The input_ids of the model output. - Returns: - bool - True if the reasoning content ends in the input_ids. - """ - - raise NotImplementedError( - "AbstractReasoningParser.is_reasoning_end has" - "not been implemented!") - - # TODO: need to rebase by PR #14428 - @abstractmethod - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract content token ids from the input_ids. - Parameters: - input_ids: list[int] - The input_ids of the model output. - Returns: - list[int] - The extracted content from the input_ids. - """ - - raise NotImplementedError( - "AbstractReasoningParser.extract_content_ids has" - " not been implemented!") class ReasoningParserManager: @@ -125,14 +115,16 @@ class ReasoningParserManager: if name in cls.reasoning_parsers: return cls.reasoning_parsers[name] - raise KeyError(f"reasoning helper: '{name}' not found in " - "reasoning_parsers") + raise KeyError( + f"reasoning helper: '{name}' not found in reasoning_parsers") @classmethod - def _register_module(cls, - module: type, - module_name: Optional[Union[str, list[str]]] = None, - force: bool = True) -> None: + def _register_module( + cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True, + ) -> None: if not issubclass(module, ReasoningParser): raise TypeError("module must be subclass of ReasoningParser, " f"but got {type(module)}") @@ -149,13 +141,14 @@ class ReasoningParserManager: @classmethod def register_module( - cls, - name: Optional[Union[str, list[str]]] = None, - force: bool = True, - module: Union[type, None] = None) -> Union[type, Callable]: + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None, + ) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a - decoder(with module as None) or normal function(with module as not + decoder(with module as None) or normal function(with module as not None). """ if not isinstance(force, bool): @@ -183,7 +176,7 @@ class ReasoningParserManager: @classmethod def import_reasoning_parser(cls, plugin_path: str) -> None: """ - Import a user-defined reasoning parser by the path + Import a user-defined reasoning parser by the path of the reasoning parser define file. """ module_name = os.path.splitext(os.path.basename(plugin_path))[0] diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py similarity index 64% rename from vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py rename to vllm/reasoning/deepseek_r1_reasoning_parser.py index 54e960168cf46..73be6d4d1ab13 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) @@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser): """ Reasoning parser for DeepSeek R1 model. - The DeepSeek R1 model uses ... tokens to denote reasoning + The DeepSeek R1 model uses ... tokens to denote reasoning text. This parser extracts the reasoning content from the model output. """ + start_token_id: int + end_token_id: int + + start_token: str = "" + end_token: str = "" + def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) - self.think_start_token = "" - self.think_end_token = "" self.reasoning_regex = re.compile( - rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) + rf"{self.start_token}(.*?){self.end_token}", re.DOTALL) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " "constructor during construction.") - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( "DeepSeek R1 reasoning parser could not locate think start/end " "tokens in the tokenizer!") - # TODO: need to rebase by PR #14428 def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids + return self.end_token_id in input_ids def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ Extract the content after the end tokens """ - if self.think_end_token_id not in input_ids[:-1]: + if self.end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.end_token_id) + 1:] def extract_reasoning_content_streaming( self, @@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser): """ # Skip single special tokens if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id + self.start_token_id, self.end_token_id ]): return None # Check if is present in previous or delta. # Keep compatibility with models that don't generate tokens. - if self.think_start_token_id in previous_token_ids: - if self.think_end_token_id in delta_token_ids: + if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: # in previous, in delta, # extract reasoning content - end_index = delta_text.find(self.think_end_token) + end_index = delta_text.find(self.end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - elif self.think_end_token_id in previous_token_ids: + content = delta_text[end_index + len(self.end_token):] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) + elif self.end_token_id in previous_token_ids: # in previous, in previous, # reasoning content continues return DeltaMessage(content=delta_text) @@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser): # in previous, no in previous or delta, # reasoning content continues return DeltaMessage(reasoning_content=delta_text) - elif self.think_start_token_id in delta_token_ids: - if self.think_end_token_id in delta_token_ids: + elif self.start_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content - start_index = delta_text.find(self.think_start_token) - end_index = delta_text.find(self.think_end_token) + start_index = delta_text.find(self.start_token) + end_index = delta_text.find(self.end_token) reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + len(self.start_token):end_index] + content = delta_text[end_index + len(self.end_token):] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) else: # in delta, no in delta, # reasoning content continues @@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser): # No in previous or delta, also need to check for . # Because the model may have generated without # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.think_end_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: # in delta with more tokens, # extract reasoning content and content - end_index = delta_text.find(self.think_end_token) + end_index = delta_text.find(self.end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - elif self.think_end_token_id in previous_token_ids: + content = delta_text[end_index + len(self.end_token):] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) + elif self.end_token_id in previous_token_ids: # in previous, thinking content ends return DeltaMessage(content=delta_text) else: @@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser): def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: - # DeepSeek R1 doesn't generate now. # Thus we assume the reasoning content is always at the start. # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.think_end_token not in model_output: + if self.end_token not in model_output: return model_output, None else: # Add a start token if it's missing to keep compatibility. - if self.think_start_token not in model_output: - model_output = f"{self.think_start_token}{model_output}" + if self.start_token not in model_output: + model_output = f"{self.start_token}{model_output}" # Use a regex to find the reasoning content reasoning_content = self.reasoning_regex.findall(model_output)[0] end_index = len( - f"{self.think_start_token}{reasoning_content}{self.think_end_token}" - ) + f"{self.start_token}{reasoning_content}{self.end_token}") final_output = model_output[end_index:] if len(final_output) == 0: diff --git a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py similarity index 99% rename from vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py rename to vllm/reasoning/granite_reasoning_parser.py index 117d051a73782..249ace1f167fa 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__)