From 7faf51f1ccc1c74ac4064e44494113117d2a36c9 Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Fri, 3 Oct 2025 14:13:34 +0200
Subject: [PATCH] [Bugfix] Re-enable prefill of max model length (#24446)

Signed-off-by: Yannick Schnider
Signed-off-by: yewentao256
---
 tests/v1/e2e/test_context_length.py | 91 +++++++++++++++++++++++++++++
 vllm/v1/worker/gpu_model_runner.py  | 28 ++++++---
 2 files changed, 112 insertions(+), 7 deletions(-)
 create mode 100644 tests/v1/e2e/test_context_length.py

diff --git a/tests/v1/e2e/test_context_length.py b/tests/v1/e2e/test_context_length.py
new file mode 100644
index 0000000000000..67a6c7be4432e
--- /dev/null
+++ b/tests/v1/e2e/test_context_length.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+End-to-end tests for context-length corner cases of the vLLM v1 model runner
+versus HuggingFace's transformers.
+
+This test verifies the following behavior: allow a prefill that fills the
+model's maximum context length and then request a single new token.
+
+Test strategy
+- Build a prompt of token IDs that is exactly ``max_model_len`` tokens long.
+- Run vLLM generation requesting a single new token (max_tokens=1).
+- Run HF generation on the same prompt, also requesting a single token.
+- Assert both return the same number of generated tokens and the same IDs.
+
+"""
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+
+from tests.models.utils import check_outputs_equal
+from tests.utils import create_new_process_for_each_test
+from vllm import LLM, SamplingParams
+from vllm.inputs import TokensPrompt
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
+@pytest.mark.parametrize("max_model_len", [2048])
+@pytest.mark.parametrize("max_tokens", [1])
+def test_prefill_max_context_length(
+    model: str,
+    max_model_len: int,
+    max_tokens: int,
+) -> None:
+    """Compare vLLM and HuggingFace when the prompt already fills the
+    model's maximum context length and we request a single new token.
+
+    The test ensures vLLM does not raise the "Sampled token IDs exceed the
+    max model length" assertion and that both vLLM and HF produce the same
+    single token when given the same inputs.
+    """
+
+    # Construct a prompt of size max_model_len.
+    prompt_ids = [[43] * max_model_len]
+
+    # Generate max_tokens new tokens deterministically.
+    sampling_params = [
+        SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
+    ]
+
+    # --- vLLM generation ---
+    llm = LLM(
+        model=model,
+        tokenizer=model,
+        max_num_seqs=1,
+        tensor_parallel_size=1,
+    )
+
+    vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
+    vllm_results = llm.generate(vllm_token_prompts, sampling_params)
+
+    vllm_output_ids = vllm_results[0].outputs[0].token_ids
+
+    # --- HuggingFace generation ---
+    with torch.no_grad():
+        hf_model = AutoModelForCausalLM.from_pretrained(model)
+
+        # HF expects a tensor of input ids shaped (batch, seq_len).
+        hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)
+
+        # Generate max_tokens new tokens deterministically.
+        hf_generated = hf_model.generate(
+            hf_input_tokens,
+            do_sample=False,
+            min_new_tokens=max_tokens,
+            max_new_tokens=max_tokens,
+        )
+
+        # HF returns the prompt + generated tokens. Slice off the prompt.
+        hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
+
+    # Check that the vLLM outputs (token IDs) match the HF outputs.
+    # Note: for simplicity, don't pass the detokenized strings.
+    check_outputs_equal(
+        outputs_0_lst=[(hf_output_ids, "")],
+        outputs_1_lst=[(vllm_output_ids, "")],
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8b92cb052efd6..ff95acf0c0169 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2247,14 +2247,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
-            assert end_idx <= self.max_model_len, (
-                "Sampled token IDs exceed the max model length. "
-                f"Total number of tokens: {end_idx} > max_model_len: "
-                f"{self.max_model_len}")
+            assert end_idx <= self.max_model_len + 1, (
+                "Sampled token IDs exceed the max model length + 1. "
+                f"Total number of tokens: {end_idx} > max_model_len + 1: "
+                f"{self.max_model_len + 1}")
+
+            n_tokens_cache = len(sampled_ids)
+
+            # Sampled token IDs may exceed the max model length by 1. This is
+            # legitimate as we can still sample 1 last token when the context
+            # length equals the max model length. Note that we do not need to
+            # cache this token ID as the sequence finishes after this step.
+            # Additionally, the buffers token_ids_cpu and is_token_ids are of
+            # size max model length only.
+            if end_idx == self.max_model_len + 1:
+                n_tokens_cache -= 1
+
+            self.input_batch.token_ids_cpu[req_idx, start_idx:(
+                start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache]
+            self.input_batch.is_token_ids[req_idx,
+                                          start_idx:(start_idx +
+                                                     n_tokens_cache)] = True
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
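
For readers skimming the patch, the clamping logic in the gpu_model_runner.py
hunk can be read in isolation as the short sketch below. This is an
illustrative stand-in under simplifying assumptions, not vLLM code: the
function name cache_sampled_tokens, the module-level max_model_len, and the
NumPy token_ids_cpu array are hypothetical substitutes for the InputBatch
bookkeeping used above.

import numpy as np

max_model_len = 8
token_ids_cpu = np.zeros(max_model_len, dtype=np.int64)  # fixed-size buffer


def cache_sampled_tokens(start_idx: int, sampled_ids: list[int]) -> int:
    """Write sampled token IDs into the buffer, dropping the single token
    that may legitimately land one past the end when the context is full."""
    end_idx = start_idx + len(sampled_ids)
    assert end_idx <= max_model_len + 1

    n_tokens_cache = len(sampled_ids)
    if end_idx == max_model_len + 1:
        # The sequence finishes after this step, so the final token never
        # needs to be cached; the buffer only holds max_model_len entries.
        n_tokens_cache -= 1

    token_ids_cpu[start_idx:start_idx + n_tokens_cache] = (
        sampled_ids[:n_tokens_cache])
    return end_idx


# A prompt that already fills the context leaves start_idx == max_model_len,
# so the one newly sampled token is returned but never written to the buffer.
assert cache_sampled_tokens(max_model_len, [43]) == max_model_len + 1

The point of the clamp is that the one token sampled at a full context is
still reported back through the num_tokens bookkeeping, while token_ids_cpu
and is_token_ids, which only hold max_model_len entries, are never written
out of bounds.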