From a39203f99ed426eef8b55927cb8f8668644d0a86 Mon Sep 17 00:00:00 2001 From: mofanke <54242816+mofanke@users.noreply.github.com> Date: Wed, 30 Apr 2025 00:32:40 +0800 Subject: [PATCH] =?UTF-8?q?[Bugfix]=20add=20qwen3=20reasoning-parser=20fix?= =?UTF-8?q?=20content=20is=20None=20when=20disable=20=E2=80=A6=20(#17369)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: mofanke --- docs/source/features/reasoning_outputs.md | 1 + .../reasoning/test_qwen3_reasoning_parser.py | 141 ++++++++++++++++++ vllm/reasoning/__init__.py | 2 + vllm/reasoning/qwen3_reasoning_parser.py | 138 +++++++++++++++++ 4 files changed, 282 insertions(+) create mode 100644 tests/reasoning/test_qwen3_reasoning_parser.py create mode 100644 vllm/reasoning/qwen3_reasoning_parser.py diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 3a0be69f8e1c6..323bf849a082d 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -15,6 +15,7 @@ vLLM currently supports the following reasoning models: | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | +| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. 
diff --git a/tests/reasoning/test_qwen3_reasoning_parser.py b/tests/reasoning/test_qwen3_reasoning_parser.py
new file mode 100644
index 0000000000000..95b7460d359e4
--- /dev/null
+++ b/tests/reasoning/test_qwen3_reasoning_parser.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+parser_name = "qwen3"
+start_token = "<think>"
+end_token = "</think>"
+
+REASONING_MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+
+@pytest.fixture(scope="module")
+def qwen3_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+# with <think></think>, non-stream
+WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+# with <think></think>, stream
+WITH_THINK_STREAM = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+# without <think></think>, non-stream
+WITHOUT_THINK = {
+    "output": "This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+}
+# without <think></think>, stream
+WITHOUT_THINK_STREAM = {
+    "output": "This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+}
+
+COMPLETE_REASONING = {
+    "output": "<think>This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+MULTILINE_REASONING = {
+    "output":
+    "<think>This is a reasoning\nsection</think>This is the rest\nThat",
+    "reasoning_content": "This is a reasoning\nsection",
+    "content": "This is the rest\nThat",
+}
+ONLY_OPEN_TAG = {
+    "output": "<think>This is a reasoning section",
+    "reasoning_content": None,
+    "content": "<think>This is a reasoning section",
+}
+
+ONLY_OPEN_TAG_STREAM = {
+    "output": "<think>This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+
+TEST_CASES = [
+    pytest.param(
+        False,
+        WITH_THINK,
+        id="with_think",
+    ),
+    pytest.param(
+        True,
+        WITH_THINK_STREAM,
+        id="with_think_stream",
+    ),
+    pytest.param(
+        False,
+        WITHOUT_THINK,
+        id="without_think",
+    ),
+    pytest.param(
+        True,
+        WITHOUT_THINK_STREAM,
+        id="without_think_stream",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING,
+        id="complete_reasoning",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING,
+        id="complete_reasoning_stream",
+    ),
+    pytest.param(
+        False,
+        MULTILINE_REASONING,
+        id="multiline_reasoning",
+    ),
+    pytest.param(
+        True,
+        MULTILINE_REASONING,
+        id="multiline_reasoning_stream",
+    ),
+    pytest.param(
+        False,
+        ONLY_OPEN_TAG,
+        id="only_open_tag",
+    ),
+    pytest.param(
+        True,
+        ONLY_OPEN_TAG_STREAM,
+        id="only_open_tag_stream",
+    ),
+]
+
+
+@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
+def test_reasoning(
+    streaming: bool,
+    param_dict: dict,
+    qwen3_tokenizer,
+):
+    output = qwen3_tokenizer.tokenize(param_dict["output"])
+    output_tokens: list[str] = [
+        qwen3_tokenizer.convert_tokens_to_string([token]) for token in output
+    ]
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
+        parser_name)(qwen3_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(parser,
+                                                  output_tokens,
+                                                  streaming=streaming)
+
+    assert reasoning == param_dict["reasoning_content"]
+    assert content == param_dict["content"]
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 45132a780e5b2..65606ce55af72 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -3,10 +3,12 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
+from .qwen3_reasoning_parser import Qwen3ReasoningParser
 
 __all__ = [
     "ReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
+    "Qwen3ReasoningParser",
 ]
diff --git
a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py
new file mode 100644
index 0000000000000..78a73011ff88f
--- /dev/null
+++ b/vllm/reasoning/qwen3_reasoning_parser.py
@@ -0,0 +1,138 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("qwen3")
+class Qwen3ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the Qwen3 model.
+
+    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
+    within its output. The model provides a strict switch to disable reasoning
+    output via the 'enable_thinking=False' parameter. This parser extracts the
+    reasoning content enclosed by <think> and </think> tokens from the model's
+    output.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+        self.reasoning_regex = re.compile(
+            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction.")
+
+        self.think_start_token_id = self.vocab.get(self.think_start_token)
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+        if (self.think_start_token_id is None
+                or self.think_end_token_id is None):
+            raise RuntimeError(
+                "Qwen3 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!")
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.think_start_token_id, self.think_end_token_id
+        ]):
+            return None
+
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # content continues after the reasoning section
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.think_start_token_id in delta_token_ids:
+            logger.info(delta_text)
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta, extract reasoning content
+                start_index = delta_text.find(self.think_start_token)
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.think_start_token
+                                                   ):end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # thinking is disabled, just content
+            return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+
+        # Check if the model output contains the <think> and </think> tokens.
+        if (self.think_start_token not in model_output
+                or self.think_end_token not in model_output):
+            return None, model_output
+        else:
+            # Use a regex to find the reasoning content
+            reasoning_content = self.reasoning_regex.findall(model_output)[0]
+
+            # Remove the reasoning content from the model output
+            # Although <think> token is always at the
+            # beginning of the line, we cannot guarantee that the
+            # other models will follow this convention.
+            # Therefore, we need to add :start_index.
+            start_index = model_output.find(self.think_start_token)
+            if start_index != -1:
+                end_index = start_index + len(
+                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
+                )
+                model_output = model_output[:start_index] + \
+                    model_output[end_index:]
+
+                if len(model_output) == 0:
+                    return reasoning_content, None
+
+            return reasoning_content, model_output