From 6796ce8bdbf29f5624fcdc03792626574c919b41 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Thu, 4 Dec 2025 19:11:59 +0800 Subject: [PATCH] [Bugfix] Fix the issue with interleaved thinking when using streaming (#30033) Signed-off-by: chaunceyjiang Signed-off-by: Chauncey Co-authored-by: Cyrus Leung --- .../reasoning/test_base_thinking_reasoning_parser.py | 12 +++++++++++- vllm/reasoning/basic_parsers.py | 9 ++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py index d31b1c7d169b7..34e9483de54b3 100644 --- a/tests/reasoning/test_base_thinking_reasoning_parser.py +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods: """Test the is_reasoning_end method.""" parser = TestThinkingReasoningParser(test_tokenizer) end_token_id = parser.end_token_id - + start_token_id = parser.start_token_id # Test with end token present assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True @@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods: # Test with empty list assert parser.is_reasoning_end([]) is False + # Test with interleaved thinking + assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True + assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False + assert ( + parser.is_reasoning_end( + [1, start_token_id, 2, end_token_id, 2, 2, start_token_id] + ) + is False + ) + def test_extract_content_ids(self, test_tokenizer): """Test the extract_content_ids method.""" parser = TestThinkingReasoningParser(test_tokenizer) diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index 35084c0e7cc86..e78ac4a5ebb37 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -64,8 +64,15 @@ class BaseThinkingReasoningParser(ReasoningParser): ) def is_reasoning_end(self, input_ids: list[int]) -> bool: + start_token_id = self.start_token_id end_token_id = self.end_token_id - return any(input_id == end_token_id for input_id in reversed(input_ids)) + + for i in range(len(input_ids) - 1, -1, -1): + if input_ids[i] == start_token_id: + return False + if input_ids[i] == end_token_id: + return True + return False def extract_content_ids(self, input_ids: list[int]) -> list[int]: """