[Bugfix] Fix validate model input for decoder models (#27099)

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Author: Yannick Schnider, 2025-11-13 19:18:47 +01:00 (committed by GitHub)
Parent: fe1cd7704d
Commit: 119c4927b3
2 changed files with 78 additions and 0 deletions


@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()`
handling of maximum context length for decoder models.
This test ensures:
- A prompt that is one token shorter than the model's maximum context length
can be processed successfully when requesting one additional token.
- A prompt that reaches the model's maximum context length throws a
`ValueError` when requesting at least one additional token.
"""
import pytest

from tests.conftest import VllmRunner
from tests.utils import create_new_process_for_each_test


@create_new_process_for_each_test()
@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)])
@pytest.mark.parametrize(
"prompt_len, max_tokens",
[
(2047, 1), # prompt_len = max_model_len - 1 -> allowed
(2048, 1), # prompt_len = max_model_len -> not allowed
],
)
def test_decoder_max_context_length_validation(
    model: str,
    max_model_len: int,
    vllm_runner: type[VllmRunner],
    prompt_len: int,
    max_tokens: int,
) -> None:
    """Check vLLM decoder model input validation for edge cases where
    the prompt length is (almost) equal to the max model length."""
    prompt_ids = [[43] * prompt_len]
    with vllm_runner(
        model_name=model,
        tokenizer_name=model,
        max_model_len=max_model_len,
        max_num_seqs=1,
        tensor_parallel_size=1,
    ) as vllm_model:
        if prompt_len + max_tokens <= max_model_len:
            # Should succeed as constraints are met
            vllm_model.generate_greedy(prompt_ids, max_tokens)
        else:
            # Should raise the ValueError defined in
            # vllm/v1/engine/processor.Processor._validate_model_input()
            expected_msg = (
                f"The decoder prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than "
                f"the maximum model length of {max_model_len}. "
                "Make sure that `max_model_len` is no smaller than the number of "
                "text tokens (prompt + requested output tokens)."
            )
            with pytest.raises(ValueError) as excinfo:
                vllm_model.generate_greedy(prompt_ids, max_tokens)
            assert expected_msg in str(excinfo.value)
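
Outside the test harness, the same boundary can be hit through vLLM's offline-inference API. The snippet below is a minimal sketch and not part of this change: it assumes the public `LLM`, `SamplingParams`, and `TokensPrompt` entry points and simply shows which request shapes the new check accepts and rejects.

# Illustration only (not part of this commit); assumes the public vLLM
# offline-inference API: LLM, SamplingParams, and TokensPrompt.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="JackFram/llama-160m", max_model_len=2048, max_num_seqs=1)
params = SamplingParams(temperature=0.0, max_tokens=1)

# prompt_len == max_model_len: no room is left for even one output token,
# so the Processor is expected to reject the request with a ValueError.
full_prompt = TokensPrompt(prompt_token_ids=[43] * 2048)
try:
    llm.generate(full_prompt, params)
except ValueError as err:
    print(f"Rejected as expected: {err}")

# prompt_len == max_model_len - 1: one output token still fits, so this runs.
ok_prompt = TokensPrompt(prompt_token_ids=[43] * 2047)
outputs = llm.generate(ok_prompt, params)
print(len(outputs[0].outputs[0].token_ids))  # expected: 1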


@@ -575,6 +575,21 @@ class Processor:
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens

        if (
            prompt_len == max_prompt_len
            and prompt_type == "decoder"
            and not model_config.is_multimodal_model
        ):
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        return self.input_preprocessor.stat_mm_cache()
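
To make the boundary explicit, here is a self-contained sketch of the condition the new branch enforces (simplified names, not the actual `Processor` code): a decoder prompt that already fills the context window leaves no room for the at-least-one output token every request implies, so it is rejected even though the prompt itself does not exceed `max_model_len`. The real check additionally skips non-decoder prompts and multimodal models, as shown in the diff above.

# Simplified, standalone illustration of the new validation; not vLLM code.
def check_decoder_prompt_fits(prompt_len: int, max_model_len: int) -> None:
    """Reject prompts that completely fill the context window, since at
    least one output token must also fit within max_model_len."""
    if prompt_len == max_model_len:
        raise ValueError(
            f"The decoder prompt (length {prompt_len}) plus the number of "
            f"requested output tokens (at least 1) is longer than the "
            f"maximum model length of {max_model_len}. Make sure that "
            "`max_model_len` is no smaller than the number of text tokens "
            "(prompt + requested output tokens)."
        )


check_decoder_prompt_fits(2047, 2048)  # passes: one output token still fits
check_decoder_prompt_fits(2048, 2048)  # raises ValueError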