From f08919b7d17f6adf4eea26ca410c4a3de41a71da Mon Sep 17 00:00:00 2001
From: Elaine Zhao
Date: Wed, 8 Oct 2025 14:08:24 -0700
Subject: [PATCH] [Bugfix] Respect min_tokens in scheduler stop check (#26317)

Signed-off-by: Elaine Zhao
---
 tests/v1/core/test_scheduler.py | 90 +++++++++++++++++++++++++++++++++
 vllm/v1/core/sched/utils.py     |  5 ++
 2 files changed, 95 insertions(+)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index e78cced2d2db4..5baa2a1be4ab0 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -497,6 +497,96 @@ def test_stop_via_update_from_output():
     assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
 
 
+def test_check_stop_min_tokens():
+    """Test that requests don't stop when min_tokens requirement isn't met."""
+    from vllm.v1.core.sched.utils import check_stop
+
+    # Test case 1: num_output_tokens < min_tokens
+    # Should return False (don't stop)
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=20,
+        min_tokens=5,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1, 2],
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+    )
+    # Simulate having generated 3 output tokens (less than min_tokens=5)
+    request.append_output_token_ids([10, 11, EOS_TOKEN_ID])  # EOS token present
+
+    result = check_stop(request, max_model_len=100)
+    assert result is False, "Should not stop when num_output_tokens < min_tokens"
+
+    # Test case 2: num_output_tokens >= min_tokens
+    # Should follow normal stopping logic (stop on EOS)
+    request.append_output_token_ids(
+        [
+            10,
+            11,
+            12,
+            13,
+            14,
+            EOS_TOKEN_ID,
+        ]
+    )  # 6 tokens > min_tokens
+
+    result = check_stop(request, max_model_len=100)
+    assert result is True, "Should stop on EOS when min_tokens met"
+    assert request.status == RequestStatus.FINISHED_STOPPED
+
+    # Test case 3: min_tokens = 0, should follow normal stopping logic
+    sampling_params_no_min = SamplingParams(
+        ignore_eos=False,
+        max_tokens=20,
+        min_tokens=0,
+    )
+    request_no_min = Request(
+        request_id="1",
+        prompt_token_ids=[0, 1, 2],
+        sampling_params=sampling_params_no_min,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+    )
+    request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
+
+    result = check_stop(request_no_min, max_model_len=100)
+    assert result is True, "Should stop on EOS when min_tokens=0"
+    assert request_no_min.status == RequestStatus.FINISHED_STOPPED
+
+    # Test case 4: min_tokens > 0 with stop token (not EOS)
+    sampling_params_stop = SamplingParams(
+        ignore_eos=False,
+        max_tokens=20,
+        min_tokens=5,
+        stop_token_ids=[42],
+    )
+    request_stop = Request(
+        request_id="2",
+        prompt_token_ids=[0, 1, 2],
+        sampling_params=sampling_params_stop,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+    )
+    # Only 3 output tokens, less than min_tokens=5, but has stop token
+    request_stop.append_output_token_ids([10, 11, 42])
+    result = check_stop(request_stop, max_model_len=100)
+    assert result is False, "Should not stop when num_output_tokens < min_tokens"
+
+    # Append more tokens ending in the stop token
+    request_stop.append_output_token_ids([12, 13, 42])  # 6 tokens >= min_tokens=5
+
+    result = check_stop(request_stop, max_model_len=100)
+    assert result is True, "Should stop on stop token when min_tokens met"
+    assert request_stop.status == RequestStatus.FINISHED_STOPPED
+    assert request_stop.stop_reason == 42
+
+
 @pytest.mark.parametrize(
     "enable_prefix_caching, prompt_logprobs",
     [
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py
index 4f17468d2d581..c5989b37d22d7 100644
--- a/vllm/v1/core/sched/utils.py
+++ b/vllm/v1/core/sched/utils.py
@@ -58,6 +58,11 @@ def check_stop(
     sampling_params = request.sampling_params
     assert sampling_params is not None
+
+    min_tokens = sampling_params.min_tokens
+    if request.num_output_tokens < min_tokens:
+        return False
+
     last_token_id = request.output_token_ids[-1]
     if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
         request.status = RequestStatus.FINISHED_STOPPED
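
Note (illustration only, not part of the patch): the fix places the
min_tokens guard ahead of every stop condition, so an EOS or stop token
sampled too early can no longer finish the request. Below is a minimal,
self-contained sketch of the patched control flow; FakeRequest and
check_stop_sketch are hypothetical stand-ins, not vLLM's actual classes
or function signatures.

    from dataclasses import dataclass, field

    @dataclass
    class FakeRequest:
        eos_token_id: int
        min_tokens: int
        ignore_eos: bool = False
        stop_token_ids: tuple[int, ...] = ()
        output_token_ids: list[int] = field(default_factory=list)

    def check_stop_sketch(req: FakeRequest) -> bool:
        # The patched guard: never stop before min_tokens outputs exist,
        # regardless of what the last sampled token is.
        if len(req.output_token_ids) < req.min_tokens:
            return False
        last = req.output_token_ids[-1]
        if not req.ignore_eos and last == req.eos_token_id:
            return True  # normal EOS stop
        return last in req.stop_token_ids  # user-provided stop tokens

    req = FakeRequest(eos_token_id=2, min_tokens=5, output_token_ids=[10, 11, 2])
    assert check_stop_sketch(req) is False  # early EOS ignored: 3 < min_tokens
    req.output_token_ids += [12, 13, 2]     # now 6 >= min_tokens
    assert check_stop_sketch(req) is True   # EOS respected once min_tokens met

Placing the guard before the EOS and stop-token checks, rather than inside
each of them, keeps the scheduler's stop check consistent with how the
sampler already suppresses stop conditions below min_tokens.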