[Core] Enable decode of context length equal to max model length (#26168)
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
This commit is contained in:
parent d0df145c2a
commit f05fea1f5e
@@ -82,10 +82,11 @@ def test_max_model_len():
     for output in outputs:
         num_total_tokens = len(output.prompt_token_ids) + len(
             output.outputs[0].token_ids)
-        # Total tokens must not exceed max_model_len.
+        # Total tokens must not exceed max_model_len + 1 (the last token can be
+        # generated with the context length equal to the max model length)
         # It can be less if generation finishes due to other reasons (e.g., EOS)
         # before reaching the absolute model length limit.
-        assert num_total_tokens <= max_model_len
+        assert num_total_tokens <= max_model_len + 1
 
 
 def test_log_stats():
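
A small worked example of the relaxed bound, using the same numbers as the parametrized test below but otherwise purely illustrative: when the prompt already fills the context window, exactly one more token may still be sampled, so the total can reach max_model_len + 1.

# Illustrative numbers only (they happen to match the 2048-token test case):
max_model_len = 2048
prompt_len = 2048              # prompt already fills the context window
num_generated = 1              # one token sampled at context length == max_model_len
num_total_tokens = prompt_len + num_generated
assert num_total_tokens <= max_model_len + 1   # 2049 <= 2049
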
@@ -4,15 +4,22 @@
 end-to-end tests for context length corner cases of vLLM v1 model runner
 versus HuggingFace's transformers.
 
-This test verifies the following behavior: allow a prefill that fills the
-model's maximum context length and then request a single new token.
+This test verifies the following behavior: allow prefill and decodes on the
+model's maximum context length ``max_model_len`` and get one more token.
 
 Test strategy
-- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
-- Run vLLM generation requesting a single new token (max_tokens=1).
-- Run HF generation on the same prompt requesting a single token too.
+- Build a prompt consisting of exactly ``prompt_len`` tokens.
+- Run vLLM generation requesting ``max_tokens`` new tokens.
+- Run HF generation on the same prompt requesting the same number of tokens.
 - Assert both return the same number of generated tokens and the same ids.
 
+Test cases
+- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
+  will be sampled after the prefill (context length ``max_model_len``).
+- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
+  the 1st will be sampled after the prefill and the 2nd after the first decode
+  (context length ``max_model_len``).
+
 """
 
 import pytest
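
Both test cases end with a token sampled at context length ``max_model_len``. A minimal sketch of that bookkeeping, assuming the simple per-step accounting described in the docstring rather than anything taken from the test code itself:

def sampling_context_lengths(prompt_len: int, max_tokens: int) -> list[int]:
    # The i-th new token is sampled with the prompt plus the i previously
    # generated tokens already in context.
    return [prompt_len + i for i in range(max_tokens)]

assert sampling_context_lengths(2048, 1) == [2048]        # sampled right after prefill
assert sampling_context_lengths(2047, 2) == [2047, 2048]  # prefill, then one decode
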
@@ -27,11 +34,16 @@ from vllm.inputs import TokensPrompt
 
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("model", ["JackFram/llama-160m"])
-@pytest.mark.parametrize("max_model_len", [2048])
-@pytest.mark.parametrize("max_tokens", [1])
-def test_prefill_max_context_length(
+@pytest.mark.parametrize(
+    "prompt_len, max_tokens",
+    [
+        (2048, 1),  # prompt_len = max_model_len
+        (2047, 2),  # prompt_len = max_model_len - 1
+    ],
+)
+def test_max_context_length(
     model: str,
-    max_model_len: int,
+    prompt_len: int,
     max_tokens: int,
 ) -> None:
     """Compare vLLM and HuggingFace when the prompt already fills the
@@ -42,8 +54,8 @@ def test_prefill_max_context_length(
     single token when given the same inputs.
     """
 
-    # Construct a prompt of size max_model_len
-    prompt_ids = [[43] * max_model_len]
+    # Construct a prompt of size prompt_len
+    prompt_ids = [[43] * prompt_len]
 
     # Generate max_tokens new tokens deterministically.
     sampling_params = [
@@ -54,6 +66,7 @@ def test_prefill_max_context_length(
     llm = LLM(
         model=model,
         tokenizer=model,
+        max_model_len=2048,
         max_num_seqs=1,
         tensor_parallel_size=1,
     )
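
For readers skimming the diff, a hedged sketch of how a test like this can drive vLLM with raw token ids; the exact arguments used by the real test may differ from this illustration:

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

prompt_len, max_tokens = 2048, 1        # first parametrized case
prompt_ids = [43] * prompt_len          # synthetic prompt of repeated token id 43

llm = LLM(model="JackFram/llama-160m", max_model_len=2048, max_num_seqs=1)
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)

outputs = llm.generate([TokensPrompt(prompt_token_ids=prompt_ids)], sampling_params)
vllm_output_ids = list(outputs[0].outputs[0].token_ids)
assert len(vllm_output_ids) == max_tokens
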
@@ -81,6 +94,9 @@ def test_prefill_max_context_length(
     # HF returns the prompt + generated tokens. Slice off the prompt.
     hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
 
+    # check that exactly max_tokens tokens were generated with vLLM and HF
+    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
+
     # check that vLLM outputs (token ids) match HF outputs
     # Note: for simplicity don't pass detokenized string
     check_outputs_equal(
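
The HuggingFace reference that the new assertion compares against can be sketched as follows; device placement, dtype, and attention-mask handling here are assumptions, not details from the test:

import torch
from transformers import AutoModelForCausalLM

prompt_ids = [43] * 2048                 # same synthetic prompt as the vLLM sketch above
max_tokens = 1

hf_model = AutoModelForCausalLM.from_pretrained("JackFram/llama-160m")
input_ids = torch.tensor([prompt_ids], dtype=torch.long)
with torch.no_grad():
    hf_generated = hf_model.generate(input_ids, max_new_tokens=max_tokens, do_sample=False)

# HF returns the prompt plus the generated tokens; slice off the prompt.
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids):]
assert len(hf_output_ids) == max_tokens
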

@@ -224,7 +224,7 @@ class Scheduler(SchedulerInterface):
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
-                self.max_model_len - 1 - request.num_computed_tokens)
+                self.max_model_len - request.num_computed_tokens)
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None

@@ -43,7 +43,7 @@ def remove_all(lst: list, items_to_remove: set) -> list:
 def check_stop(request: Request,
                max_model_len: int,
                pooler_output: Optional[torch.Tensor] = None) -> bool:
-    if (request.num_tokens >= max_model_len
+    if (request.num_tokens > max_model_len
             or request.num_output_tokens >= request.max_tokens):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True
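
Taken together, the scheduler and check_stop changes move the length boundary by exactly one token: the scheduler now lets the computed context reach max_model_len before sampling, and a request is only length-capped once its total token count exceeds max_model_len. A minimal self-contained sketch of that boundary (the helper names are illustrative, not vLLM internals):

MAX_MODEL_LEN = 2048

def schedulable_tokens(num_new_tokens: int, num_computed_tokens: int) -> int:
    # Mirrors the new scheduler cap: the context may grow all the way to
    # MAX_MODEL_LEN (the old cap stopped one token short).
    return min(num_new_tokens, MAX_MODEL_LEN - num_computed_tokens)

def length_capped(num_tokens: int) -> bool:
    # Mirrors the relaxed stop check: finish for length only once the total
    # token count (prompt + outputs) exceeds MAX_MODEL_LEN.
    return num_tokens > MAX_MODEL_LEN

# A 2048-token prompt can now be prefilled in full, one token is sampled at
# context length 2048, and only then is the request length-capped.
assert schedulable_tokens(2048, 0) == 2048
assert not length_capped(2048)   # still allowed to sample one more token
assert length_capped(2049)       # capped after that token is produced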