mirror of https://git.datalinker.icu/vllm-project/vllm.git
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent caf8b1c084
commit 31a4b3e6c4
@@ -85,11 +85,10 @@ def test_max_model_len():
         num_total_tokens = len(output.prompt_token_ids) + len(
             output.outputs[0].token_ids
         )
-        # Total tokens must not exceed max_model_len + 1 (the last token can be
-        # generated with the context length equal to the max model length)
-        assert num_total_tokens <= max_model_len + 1
+        # Total tokens must not exceed max_model_len.
+        # It can be less if generation finishes due to other reasons (e.g., EOS)
+        # before reaching the absolute model length limit.
+        assert num_total_tokens <= max_model_len
 
 
 def test_log_stats():
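To make the tightened assertion concrete, here is a minimal stand-alone sketch (not part of the commit) of the invariant the updated test enforces. The model name, prompt, and sampling settings are illustrative, and it assumes the engine caps generation at the model length rather than rejecting an over-long request:

    from vllm import LLM, SamplingParams

    max_model_len = 2048
    llm = LLM(model="JackFram/llama-160m", max_model_len=max_model_len)
    # Deliberately request more new tokens than can possibly fit.
    params = SamplingParams(max_tokens=2 * max_model_len, ignore_eos=True)

    for output in llm.generate(["Hello, my name is"], params):
        num_total_tokens = len(output.prompt_token_ids) + len(
            output.outputs[0].token_ids
        )
        # After this commit the hard cap is max_model_len, with no extra token.
        assert num_total_tokens <= max_model_len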
@@ -1,90 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-end-to-end tests for context length corner cases of vLLM v1 model runner
-versus HuggingFace's transformers.
-
-This test verifies the following behavior: allow prefill and decodes on the
-model's maximum context length ``max_model_len`` and get one more token.
-
-Test strategy
-- Build a prompt consisting of exactly ``prompt_len`` tokens.
-- Run vLLM generation requesting ``max_tokens`` new tokens.
-- Run HF generation on the same prompt requesting the same number of tokens.
-- Assert both return the same number of generated tokens and the same ids.
-
-Test cases
-- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
-  will be sampled after the prefill (context length ``max_model_len``).
-- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
-  the 1st will be sampled after the prefill and the 2nd after the first decode
-  (context length ``max_model_len``).
-
-"""
-
-import pytest
-
-from tests.conftest import HfRunner, VllmRunner
-from tests.models.utils import check_outputs_equal
-from tests.utils import create_new_process_for_each_test
-
-
-@create_new_process_for_each_test()
-@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
-@pytest.mark.parametrize(
-    "prompt_len, max_tokens",
-    [
-        (2048, 1),  # prompt_len = max_model_len
-        (2047, 2),  # prompt_len = max_model_len - 1
-    ],
-)
-def test_max_context_length(
-    model: str,
-    vllm_runner: type[VllmRunner],
-    hf_runner: type[HfRunner],
-    prompt_len: int,
-    max_tokens: int,
-) -> None:
-    """Compare vLLM and HuggingFace when the prompt already fills the
-    model's maximum context length and we request a single new token.
-
-    The test ensures vLLM does not raise the "Sampled token IDs exceed the
-    max model length" assertion and that both vLLM and HF produce the same
-    single token when given the same inputs.
-    """
-
-    # Construct a prompt of size prompt_len
-    prompt_ids = [[43] * prompt_len]
-
-    # --- vLLM generation ---
-    with vllm_runner(
-        model_name=model,
-        tokenizer_name=model,
-        max_model_len=2048,
-        max_num_seqs=1,
-        tensor_parallel_size=1,
-    ) as vllm_model:
-        # Generate max_tokens new tokens deterministically.
-        vllm_outputs = vllm_model.generate_greedy(prompt_ids, max_tokens)
-
-    # --- HuggingFace generation ---
-    with hf_runner(
-        model_name=model,
-    ) as hf_model:
-        hf_outputs = hf_model.generate_greedy(prompt_ids, max_tokens)
-
-    # vLLM and HF runners return prompt + generated tokens. Slice off the prompt.
-    vllm_output_ids = vllm_outputs[0][0][prompt_len:]
-    hf_output_ids = hf_outputs[0][0][prompt_len:]
-
-    # check that exactly max_tokens tokens were generated with vLLM and HF
-    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
-
-    # check that vLLM outputs (token ids) match HF outputs
-    # Note: for simplicity don't pass detokenized string
-    check_outputs_equal(
-        outputs_0_lst=[(hf_output_ids, "")],
-        outputs_1_lst=[(vllm_output_ids, "")],
-        name_0="hf",
-        name_1="vllm",
-    )
@@ -223,7 +223,7 @@ class Scheduler(SchedulerInterface):
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - request.num_computed_tokens
+                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
             )
 
             # Schedule encoder inputs.
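The scheduler change restores the off-by-one guard on input positions. A small worked example of the clamp above, with toy numbers and plain Python, independent of the vLLM classes:

    max_model_len = 2048
    num_computed_tokens = 2040   # tokens already in the KV cache for this request
    num_new_tokens = 16          # e.g. speculative decoding proposed several tokens

    # At most max_model_len - 1 positions may ever be used as model input, so
    # the final input position can still sample one token without the sequence
    # growing past max_model_len.
    num_new_tokens = min(num_new_tokens, max_model_len - 1 - num_computed_tokens)

    assert num_new_tokens == 7   # 2040 + 7 = 2047 input positions; 2048 tokens after sampling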
@@ -44,7 +44,7 @@ def check_stop(
     request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None
 ) -> bool:
     if (
-        request.num_tokens > max_model_len
+        request.num_tokens >= max_model_len
         or request.num_output_tokens >= request.max_tokens
     ):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
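For clarity, a self-contained sketch of the stop rule as it reads after this change; the `Request` and `RequestStatus` objects are replaced by plain arguments here, so only the comparison logic is illustrated:

    def is_length_capped(
        num_tokens: int,
        num_output_tokens: int,
        max_model_len: int,
        max_tokens: int,
    ) -> bool:
        # ">=" (instead of ">") finishes the request as soon as the sequence
        # reaches max_model_len, so no token is ever sampled beyond the limit.
        return num_tokens >= max_model_len or num_output_tokens >= max_tokens


    assert is_length_capped(2048, 8, 2048, 100)       # sequence hit max_model_len
    assert is_length_capped(1500, 100, 2048, 100)     # request hit its own max_tokens
    assert not is_length_capped(2047, 8, 2048, 100)   # one step below the limit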
@@ -2317,30 +2317,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
-            assert end_idx <= self.max_model_len + 1, (
-                "Sampled token IDs exceed the max model length + 1. "
-                f"Total number of tokens: {end_idx} > max_model_len + 1: "
-                f"{self.max_model_len + 1}"
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}"
             )
-
-            n_tokens_cache = len(sampled_ids)
-
-            # Sampled token IDs exceed the max model length by 1. This is
-            # legitimate as we can still sample 1 last token when the context
-            # length equals the max model length. Note that we do not need to
-            # cache this token ID as the sequence finishes after this step.
-            # Additionally, the buffers token_ids_cpu and is_token_ids are of
-            # size max model length only.
-            if end_idx == self.max_model_len + 1:
-                n_tokens_cache -= 1
-
-            self.input_batch.token_ids_cpu[
-                req_idx, start_idx : (start_idx + n_tokens_cache)
-            ] = sampled_ids[:n_tokens_cache]
-            self.input_batch.is_token_ids[
-                req_idx, start_idx : (start_idx + n_tokens_cache)
-            ] = True
-
+            self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
+            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
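The model-runner simplification follows from the two changes above: once the scheduler and stop check guarantee that `end_idx` never exceeds `max_model_len`, the sampled IDs always fit the preallocated CPU buffers and the special case for a final out-of-range token disappears. A toy sketch of the bounded write, using NumPy and illustrative sizes rather than the real vLLM `InputBatch`:

    import numpy as np

    max_model_len = 8
    # One request; exactly max_model_len slots, as in the real per-request buffer.
    token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int64)

    req_idx = 0
    start_idx = 6                # tokens already recorded for this request
    sampled_ids = [11, 12]       # newly sampled token IDs for this step
    end_idx = start_idx + len(sampled_ids)

    assert end_idx <= max_model_len  # mirrors the tightened assertion in the diff
    token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids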