mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 14:14:28 +08:00
[Bugfix] Respect min_tokens in scheduler stop check (#26317)
Signed-off-by: Elaine Zhao <elaineyz@amazon.com>
This commit is contained in:
parent
93f2c0aa08
commit
f08919b7d1
@ -497,6 +497,96 @@ def test_stop_via_update_from_output():
|
|||||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_stop_min_tokens():
|
||||||
|
"""Test that requests don't stop when min_tokens requirement isn't met."""
|
||||||
|
from vllm.v1.core.sched.utils import check_stop
|
||||||
|
|
||||||
|
# Test case 1: num_output_tokens < min_tokens
|
||||||
|
# Should return False (don't stop)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
ignore_eos=False,
|
||||||
|
max_tokens=20,
|
||||||
|
min_tokens=5,
|
||||||
|
)
|
||||||
|
request = Request(
|
||||||
|
request_id="0",
|
||||||
|
prompt_token_ids=[0, 1, 2],
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
)
|
||||||
|
# Simulate having generated 3 output tokens (less than min_tokens=5)
|
||||||
|
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
|
||||||
|
|
||||||
|
result = check_stop(request, max_model_len=100)
|
||||||
|
assert result is False, "Should not stop when num_output_tokens<min_tokens"
|
||||||
|
|
||||||
|
# Test case 2: num_output_tokens >= min_tokens
|
||||||
|
# Should follow normal stopping logic (stop on EOS)
|
||||||
|
request.append_output_token_ids(
|
||||||
|
[
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
13,
|
||||||
|
14,
|
||||||
|
EOS_TOKEN_ID,
|
||||||
|
]
|
||||||
|
) # 6 tokens > min_tokens
|
||||||
|
|
||||||
|
result = check_stop(request, max_model_len=100)
|
||||||
|
assert result is True, "Should stop on EOS when min_tokens met"
|
||||||
|
assert request.status == RequestStatus.FINISHED_STOPPED
|
||||||
|
|
||||||
|
# Test case 3: min_tokens = 0, should follow normal stopping logic
|
||||||
|
sampling_params_no_min = SamplingParams(
|
||||||
|
ignore_eos=False,
|
||||||
|
max_tokens=20,
|
||||||
|
min_tokens=0,
|
||||||
|
)
|
||||||
|
request_no_min = Request(
|
||||||
|
request_id="1",
|
||||||
|
prompt_token_ids=[0, 1, 2],
|
||||||
|
sampling_params=sampling_params_no_min,
|
||||||
|
pooling_params=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
)
|
||||||
|
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
|
||||||
|
|
||||||
|
result = check_stop(request_no_min, max_model_len=100)
|
||||||
|
assert result is True, "Should stop on EOS when min_tokens=0"
|
||||||
|
assert request_no_min.status == RequestStatus.FINISHED_STOPPED
|
||||||
|
|
||||||
|
# Test case 4: min_tokens > 0 with stop token (not EOS)
|
||||||
|
sampling_params_stop = SamplingParams(
|
||||||
|
ignore_eos=False,
|
||||||
|
max_tokens=20,
|
||||||
|
min_tokens=5,
|
||||||
|
stop_token_ids=[42],
|
||||||
|
)
|
||||||
|
request_stop = Request(
|
||||||
|
request_id="2",
|
||||||
|
prompt_token_ids=[0, 1, 2],
|
||||||
|
sampling_params=sampling_params_stop,
|
||||||
|
pooling_params=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
)
|
||||||
|
# Only 3 output tokens, less than min_tokens=5, but has stop token
|
||||||
|
request_stop.append_output_token_ids([10, 11, 42])
|
||||||
|
result = check_stop(request_stop, max_model_len=100)
|
||||||
|
assert result is False, "Should not stop when num_output_tokens<min_tokens"
|
||||||
|
|
||||||
|
# Test case 5: min_tokens met, should stop on stop token
|
||||||
|
request_stop.append_output_token_ids(
|
||||||
|
[10, 11, 12, 13, 14, 42]
|
||||||
|
) # 6 tokens >= min_tokens=5
|
||||||
|
|
||||||
|
result = check_stop(request_stop, max_model_len=100)
|
||||||
|
assert result is True, "Should stop on stop token when min_tokens met"
|
||||||
|
assert request_stop.status == RequestStatus.FINISHED_STOPPED
|
||||||
|
assert request_stop.stop_reason == 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"enable_prefix_caching, prompt_logprobs",
|
"enable_prefix_caching, prompt_logprobs",
|
||||||
[
|
[
|
||||||
|
|||||||
@ -58,6 +58,11 @@ def check_stop(
|
|||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
assert sampling_params is not None
|
assert sampling_params is not None
|
||||||
|
|
||||||
|
min_tokens = sampling_params.min_tokens
|
||||||
|
if request.num_output_tokens < min_tokens:
|
||||||
|
return False
|
||||||
|
|
||||||
last_token_id = request.output_token_ids[-1]
|
last_token_id = request.output_token_ids[-1]
|
||||||
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
|
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
|
||||||
request.status = RequestStatus.FINISHED_STOPPED
|
request.status = RequestStatus.FINISHED_STOPPED
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user