diff --git a/tests/engine/test_stop_checker.py b/tests/engine/test_stop_checker.py new file mode 100644 index 0000000000000..3d1e1c8032a48 --- /dev/null +++ b/tests/engine/test_stop_checker.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.reasoning import ReasoningParser +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus + +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +class MockReasoningParser(ReasoningParser): + """Mock reasoning parser for testing purposes.""" + + def __init__(self, + tokenizer: AutoTokenizer, + reasoning_active: bool = False): + super().__init__(tokenizer) + self.reasoning_active = reasoning_active + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return not self.reasoning_active + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return input_ids + + +class MockSequence(Sequence): + """Mock sequence for testing purposes.""" + + def __init__(self, token_ids, output_text="test_output", eos_token_id=0): + self.token_ids = token_ids + self.output_text = output_text + self.eos_token_id = eos_token_id + self.status = SequenceStatus.RUNNING + self.stop_reason = None + + def get_token_ids(self): + return self.token_ids + + def get_last_token_id(self): + return self.token_ids[-1] if self.token_ids else None + + def get_len(self): + return len(self.token_ids) + + def get_output_len(self): + return len(self.token_ids) - 1 # Simulating prompt + outputs + + +@pytest.fixture +def deepseek_r1_qwen_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +@pytest.fixture +def stop_checker(): + return StopChecker(max_model_len=10, + get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer) + + +@pytest.fixture +def stop_checker_with_reasoner(): + reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer) + return StopChecker(max_model_len=10, + get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer, + reasoner=reasoner) + + +def test_eos_token_stopping(stop_checker): + """Test sequence stopping when EOS token is encountered.""" + seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0) + sampling_params = SamplingParams() + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_STOPPED + + +def test_ignore_eos(stop_checker): + """Test sequence continuing when EOS token is ignored.""" + seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0) + sampling_params = SamplingParams(ignore_eos=True) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.RUNNING + + +def test_min_tokens(stop_checker): + """Test min_tokens prevents early stopping.""" + seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0) + sampling_params = SamplingParams(min_tokens=3) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.RUNNING + + +def test_stop_token_ids(stop_checker): + """Test sequence stopping with custom stop token IDs.""" + seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0) + sampling_params = SamplingParams(stop_token_ids=[3]) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.stop_reason == 3 + + +def test_stop_strings(stop_checker): + """Test sequence stopping with stop strings.""" + seq = MockSequence(token_ids=[1, 2, 3], + output_text="test output with STOP", + eos_token_id=0) + sampling_params = SamplingParams(stop=["STOP"]) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.stop_reason == "STOP" + assert "STOP" not in seq.output_text # Default behavior removes stop string + + +def test_include_stop_str_in_output(stop_checker): + """Test keeping stop strings in output.""" + seq = MockSequence(token_ids=[1, 2, 3], + output_text="test output with STOP", + eos_token_id=0) + sampling_params = SamplingParams(stop=["STOP"], + include_stop_str_in_output=True) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=5, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert "STOP" in seq.output_text + + +def test_max_tokens(stop_checker): + """Test sequence stopping at max_tokens.""" + seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0) + sampling_params = SamplingParams(max_tokens=2) + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED + + +def test_max_model_len(stop_checker): + """Test sequence stopping at max_model_len.""" + seq = MockSequence(token_ids=list(range(11)), + eos_token_id=0) # 11 tokens, max is 10 + sampling_params = SamplingParams() + + stop_checker.maybe_stop_sequence(seq, + new_char_count=1, + sampling_params=sampling_params) + + assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED + + +def test_reasoning_skip_stops(stop_checker_with_reasoner): + """Test that stop tokens and strings are ignored during reasoning.""" + # Set reasoning_active to True to simulate being in reasoning mode + stop_checker_with_reasoner.reasoner.reasoning_active = True + + # Test with stop token + seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0) + sampling_params = SamplingParams(stop_token_ids=[3]) + + stop_checker_with_reasoner.maybe_stop_sequence( + seq, new_char_count=1, sampling_params=sampling_params) + assert seq.status == SequenceStatus.RUNNING + + # Test with stop string + seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP") + sampling_params = SamplingParams(stop=["STOP"]) + + stop_checker_with_reasoner.maybe_stop_sequence( + seq, new_char_count=4, sampling_params=sampling_params) + assert seq.status == SequenceStatus.RUNNING + + # But EOS token still stops the sequence + seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0) + sampling_params = SamplingParams() + + stop_checker_with_reasoner.maybe_stop_sequence( + seq, new_char_count=1, sampling_params=sampling_params) + assert seq.status == SequenceStatus.FINISHED_STOPPED + + +def test_reasoning_end_enables_stops(stop_checker_with_reasoner): + """Test that stop tokens work after reasoning ends.""" + # Set reasoning_active to False to simulate being out of reasoning mode + stop_checker_with_reasoner.reasoner.reasoning_active = False + + # Test with stop token + seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0) + sampling_params = SamplingParams(stop_token_ids=[3]) + + stop_checker_with_reasoner.maybe_stop_sequence( + seq, new_char_count=1, sampling_params=sampling_params) + assert seq.status == SequenceStatus.FINISHED_STOPPED + + # Test with stop string + seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP") + sampling_params = SamplingParams(stop=["STOP"]) + + stop_checker_with_reasoner.maybe_stop_sequence( + seq, new_char_count=4, sampling_params=sampling_params) + assert seq.status == SequenceStatus.FINISHED_STOPPED diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c303d093f6324..f25530fc9dac8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,6 +40,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, Sequence, SequenceGroup, SequenceGroupBase, @@ -372,6 +373,14 @@ class LLMEngine: "vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + # Initialize reasoning parser if reasoning backend is set. + if self.decoding_config.reasoning_backend and \ + self.tokenizer: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + self.decoding_config.reasoning_backend) + self.reasoner: ReasoningParser = reasoner_class( + self.tokenizer.get_lora_tokenizer()) + # Create sequence output processor, e.g. for beam search or # speculative decoding. self.output_processor = ( @@ -381,8 +390,12 @@ class LLMEngine: self.scheduler, self.seq_counter, get_tokenizer_for_seq, - stop_checker=StopChecker(self.scheduler_config.max_model_len, - get_tokenizer_for_seq), + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + self.reasoner if self.decoding_config.reasoning_backend + and self.tokenizer else None, + ), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 3fb2f71b5e999..68a63044df05e 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Tuple from vllm.lora.request import LoRARequest +from vllm.reasoning import ReasoningParser from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -16,11 +17,16 @@ class StopChecker: emitted, or if we have exceeded the max model len. """ - def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + def __init__( + self, + max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + reasoner: Optional[ReasoningParser] = None, + ): # Do not use it directly, but use `self._get_max_model_len`. self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.reasoner = reasoner def _get_max_model_len(self, lora_req: Optional[LoRARequest]): if lora_req and lora_req.long_lora_max_len: @@ -57,6 +63,11 @@ class StopChecker: seq.status = SequenceStatus.FINISHED_STOPPED return + # Skip stop string/token checks if in reasoning content generation + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end(seq.get_token_ids()): + return + # Check if a stop token was encountered. # This assumes a single token produced per step. last_token_id = seq.get_last_token_id()