# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
|
|
end-to-end tests for context length corner cases of vLLM v1 model runner
|
|
versus HuggingFace's transformers.
|
|
|
|
This test verifies the following behavior: allow prefill and decodes on the
|
|
model's maximum context length ``max_model_len`` and get one more token.
|
|
|
|
Test strategy
|
|
- Build a prompt consisting of exactly ``prompt_len`` tokens.
|
|
- Run vLLM generation requesting ``max_tokens`` new tokens.
|
|
- Run HF generation on the same prompt requesting the same number of tokens.
|
|
- Assert both return the same number of generated tokens and the same ids.
|
|
|
|
Test cases
|
|
- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
|
|
will be sampled after the prefill (context length ``max_model_len``).
|
|
- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
|
|
the 1st will be sampled after the prefill and the 2nd after the first decode
|
|
(context length ``max_model_len``).
|
|
|
|
"""
|
|
|
|
import pytest
import torch
from transformers import AutoModelForCausalLM

from tests.models.utils import check_outputs_equal
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


@create_new_process_for_each_test()
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
@pytest.mark.parametrize(
    "prompt_len, max_tokens",
    [
        (2048, 1),  # prompt_len = max_model_len
        (2047, 2),  # prompt_len = max_model_len - 1
    ],
)
def test_max_context_length(
    model: str,
    prompt_len: int,
    max_tokens: int,
) -> None:
    """Compare vLLM and HuggingFace when the prompt (almost) fills the
    model's maximum context length and one or two new tokens are requested.

    The test ensures vLLM does not raise the "Sampled token IDs exceed the
    max model length" assertion and that both vLLM and HF produce the same
    tokens when given the same inputs.
    """

    # Construct a prompt of size prompt_len
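    # (built from raw token ids, an arbitrary id 43, so the length is exact;
    # a text prompt could tokenize to a different number of tokens)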
    prompt_ids = [[43] * prompt_len]

    # Generate max_tokens new tokens deterministically.
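    # (temperature=0.0 selects greedy decoding, matching HF's do_sample=False
    # below; ignore_eos=True guarantees that exactly max_tokens ids come back)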
    sampling_params = [
        SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
    ]

    # --- vLLM generation ---
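    # (max_model_len is pinned to 2048 so the first parametrized case prefills
    # a prompt that exactly fills the context window; max_num_seqs=1 keeps a
    # single sequence in flight)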
    llm = LLM(
        model=model,
        tokenizer=model,
        max_model_len=2048,
        max_num_seqs=1,
        tensor_parallel_size=1,
    )

    vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
    vllm_results = llm.generate(vllm_token_prompts, sampling_params)

    vllm_output_ids = vllm_results[0].outputs[0].token_ids
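    # (CompletionOutput.token_ids contains only the newly generated ids, so no
    # prompt slicing is needed on the vLLM side)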

    # --- HuggingFace generation ---
    with torch.no_grad():
        hf_model = AutoModelForCausalLM.from_pretrained(model)

        # HF expects a tensor of input ids shaped (batch, seq_len).
        hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)

        # Generate max_tokens new tokens deterministically.
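        # (min_new_tokens == max_new_tokens forces HF to emit exactly
        # max_tokens tokens, mirroring ignore_eos=True on the vLLM side)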
        hf_generated = hf_model.generate(
            hf_input_tokens,
            do_sample=False,
            min_new_tokens=max_tokens,
            max_new_tokens=max_tokens,
        )

    # HF returns the prompt + generated tokens. Slice off the prompt.
    hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]

    # check that exactly max_tokens tokens were generated with vLLM and HF
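    # (the final sampling step runs with a context of max_model_len tokens, so
    # both engines must still return all max_tokens ids)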
    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens

    # check that vLLM outputs (token ids) match HF outputs
    # Note: for simplicity, the detokenized strings are not compared (empty
    # strings are passed instead)
    check_outputs_equal(
        outputs_0_lst=[(hf_output_ids, "")],
        outputs_1_lst=[(vllm_output_ids, "")],
        name_0="hf",
        name_1="vllm",
    )