Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()`
handling of maximum context length for decoder models.

This test ensures:
- A prompt that is one token shorter than the model's maximum context length
  can be processed successfully when requesting one additional token.
- A prompt that reaches the model's maximum context length raises a
  `ValueError` when requesting at least one additional token.
"""

import pytest

from tests.conftest import VllmRunner
from tests.utils import create_new_process_for_each_test


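# Constraint exercised below: a request should be rejected whenever
#     len(prompt_token_ids) + max_tokens > max_model_len
# (a sketch of the expected behaviour, assuming at least one output token is
# requested; it is not a copy of the processor's implementation).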
@create_new_process_for_each_test()
@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)])
@pytest.mark.parametrize(
    "prompt_len, max_tokens",
    [
        (2047, 1),  # prompt_len = max_model_len - 1 -> allowed
        (2048, 1),  # prompt_len = max_model_len -> not allowed
    ],
)
def test_decoder_max_context_length_validation(
    model: str,
    max_model_len: int,
    vllm_runner: type[VllmRunner],
    prompt_len: int,
    max_tokens: int,
) -> None:
    """Check vLLM decoder model input validation for edge cases where
    the prompt length is (almost) equal to the max model length."""

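    # A batch with a single prompt of exactly `prompt_len` token IDs
    # (43 is an arbitrary token ID).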
    prompt_ids = [[43] * prompt_len]

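    # Load the model with the context window capped at `max_model_len`;
    # a single sequence and no tensor parallelism keep the test lightweight.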
    with vllm_runner(
        model_name=model,
        tokenizer_name=model,
        max_model_len=max_model_len,
        max_num_seqs=1,
        tensor_parallel_size=1,
    ) as vllm_model:
        if prompt_len + max_tokens <= max_model_len:
            # Should succeed as the constraints are met
            vllm_model.generate_greedy(prompt_ids, max_tokens)
        else:
            # Should raise the ValueError defined in
            # vllm/v1/engine/processor.Processor._validate_model_input()
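            # This wording must match the message raised by the processor,
            # since the test asserts it as a substring of the exception text.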
            expected_msg = (
                f"The decoder prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than "
                f"the maximum model length of {max_model_len}. "
                "Make sure that `max_model_len` is no smaller than the number of "
                "text tokens (prompt + requested output tokens)."
            )
            with pytest.raises(ValueError) as excinfo:
                vllm_model.generate_greedy(prompt_ids, max_tokens)
            assert expected_msg in str(excinfo.value)