[Bugfix] Re-enable prefill of max model length (#24446)

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Yannick Schnider 2025-10-03 14:13:34 +02:00 committed by yewentao256
parent ff1daf6c8a
commit 7faf51f1cc
2 changed files with 112 additions and 7 deletions

View File

@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
end-to-end tests for context length corner cases of vLLM v1 model runner
versus HuggingFace's transformers.
This test verifies the following behavior: allow a prefill that fills the
model's maximum context length and then request a single new token.
Test strategy
- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
- Run vLLM generation requesting a single new token (max_tokens=1).
- Run HF generation on the same prompt requesting a single token too.
- Assert both return the same number of generated tokens and the same ids.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM
from tests.models.utils import check_outputs_equal
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
@create_new_process_for_each_test()
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [1])
def test_prefill_max_context_length(
model: str,
max_model_len: int,
max_tokens: int,
) -> None:
"""Compare vLLM and HuggingFace when the prompt already fills the
model's maximum context length and we request a single new token.
The test ensures vLLM does not raise the "Sampled token IDs exceed the
max model length" assertion and that both vLLM and HF produce the same
single token when given the same inputs.
"""
# Construct a prompt of size max_model_len
prompt_ids = [[43] * max_model_len]
# Generate max_tokens new tokens deterministically.
sampling_params = [
SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
]
# --- vLLM generation ---
llm = LLM(
model=model,
tokenizer=model,
max_num_seqs=1,
tensor_parallel_size=1,
)
vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
vllm_results = llm.generate(vllm_token_prompts, sampling_params)
vllm_output_ids = vllm_results[0].outputs[0].token_ids
# --- HuggingFace generation ---
with torch.no_grad():
hf_model = AutoModelForCausalLM.from_pretrained(model)
# HF expects a tensor of input ids shaped (batch, seq_len).
hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)
# Generate max_tokens new tokens deterministically.
hf_generated = hf_model.generate(
hf_input_tokens,
do_sample=False,
min_new_tokens=max_tokens,
max_new_tokens=max_tokens,
)
# HF returns the prompt + generated tokens. Slice off the prompt.
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
# check that vLLM outputs (token ids) match HF outputs
# Note: for simplicity don't pass detokenized string
check_outputs_equal(
outputs_0_lst=[(hf_output_ids, "")],
outputs_1_lst=[(vllm_output_ids, "")],
name_0="hf",
name_1="vllm",
)

View File

@ -2247,14 +2247,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
assert end_idx <= self.max_model_len + 1, (
"Sampled token IDs exceed the max model length + 1. "
f"Total number of tokens: {end_idx} > max_model_len + 1: "
f"{self.max_model_len + 1}")
n_tokens_cache = len(sampled_ids)
# Sampled token IDs exceed the max model length by 1. This is
# legitimate as we can still sample 1 last token when the context
# length equals the max model length. Note that we do not need to
# cache this token ID as the sequence finishes after this step.
# Additionally, the buffers token_ids_cpu and is_token_ids are of
# size max model length only.
if end_idx == self.max_model_len + 1:
n_tokens_cache -= 1
self.input_batch.token_ids_cpu[req_idx, start_idx:(
start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache]
self.input_batch.is_token_ids[req_idx,
start_idx:(start_idx +
n_tokens_cache)] = True
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx