diff --git a/tests/v1/e2e/test_context_length.py b/tests/v1/e2e/test_context_length.py
new file mode 100644
index 000000000000..0ac40bec35fe
--- /dev/null
+++ b/tests/v1/e2e/test_context_length.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()`
+handling of maximum context length for decoder models.
+
+This test ensures:
+- A prompt that is one token shorter than the model's maximum context length
+  can be processed successfully when requesting one additional token.
+- A prompt that reaches the model's maximum context length raises a
+  `ValueError` when requesting at least one additional token.
+"""
+
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.utils import create_new_process_for_each_test
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)])
+@pytest.mark.parametrize(
+    "prompt_len, max_tokens",
+    [
+        (2047, 1),  # prompt_len = max_model_len - 1 -> allowed
+        (2048, 1),  # prompt_len = max_model_len -> not allowed
+    ],
+)
+def test_decoder_max_context_length_validation(
+    model: str,
+    max_model_len: int,
+    vllm_runner: type[VllmRunner],
+    prompt_len: int,
+    max_tokens: int,
+) -> None:
+    """Check vLLM decoder model input validation for edge cases where
+    the prompt length is (almost) equal to the max model length."""
+
+    prompt_ids = [[43] * prompt_len]
+
+    with vllm_runner(
+        model_name=model,
+        tokenizer_name=model,
+        max_model_len=max_model_len,
+        max_num_seqs=1,
+        tensor_parallel_size=1,
+    ) as vllm_model:
+        if prompt_len + max_tokens <= max_model_len:
+            # Should succeed as the length constraints are met
+            vllm_model.generate_greedy(prompt_ids, max_tokens)
+        else:
+            # Should raise the ValueError defined in
+            # vllm/v1/engine/processor.Processor._validate_model_input()
+            expected_msg = (
+                f"The decoder prompt (length {prompt_len}) plus the number of "
+                f"requested output tokens (at least 1) is longer than "
+                f"the maximum model length of {max_model_len}. "
+                "Make sure that `max_model_len` is no smaller than the number of "
+                "text tokens (prompt + requested output tokens)."
+            )
+            with pytest.raises(ValueError) as excinfo:
+                vllm_model.generate_greedy(prompt_ids, max_tokens)
+            assert expected_msg in str(excinfo.value)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index f2d992403e1a..69509d5d4712 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -575,6 +575,21 @@ class Processor:
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
 
+        if (
+            prompt_len == max_prompt_len
+            and prompt_type == "decoder"
+            and not model_config.is_multimodal_model
+        ):
+            suggestion = (
+                "Make sure that `max_model_len` is no smaller than the "
+                "number of text tokens (prompt + requested output tokens)."
+            )
+            raise ValueError(
+                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
+                f"requested output tokens (at least 1) is longer than the maximum "
+                f"model length of {max_prompt_len}. {suggestion}"
+            )
+
     def stat_mm_cache(self) -> MultiModalCacheStats | None:
         return self.input_preprocessor.stat_mm_cache()
 
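Outside the pytest harness, the new check can be exercised through the offline `LLM` API roughly as sketched below. This is a minimal reproduction sketch for reviewers, not part of the patch: the `TokensPrompt` usage and the assumption that the `ValueError` raised in `Processor._validate_model_input()` propagates directly out of `generate()` reflect the current offline API and may differ across vLLM versions; the model name and token counts mirror the test above.

```python
# Reproduction sketch (assumptions noted above), not part of the patch.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# Engine configured like the test: tiny model, 2048-token context window.
llm = LLM(model="JackFram/llama-160m", max_model_len=2048, max_num_seqs=1)
params = SamplingParams(max_tokens=1)

# 2047 prompt tokens + 1 requested output token fits exactly -> accepted.
llm.generate(TokensPrompt(prompt_token_ids=[43] * 2047), params)

# 2048 prompt tokens leave no room for even a single output token -> the
# new check in Processor._validate_model_input() rejects the request.
try:
    llm.generate(TokensPrompt(prompt_token_ids=[43] * 2048), params)
except ValueError as err:
    print(err)  # "The decoder prompt (length 2048) plus the number of ..."
```

Rejecting such requests in the processor means the error surfaces at submission time, rather than later as a scheduler-side failure or silently truncated output.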