diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index 5af4327b65d0..3be4530452fa 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -82,10 +82,11 @@ def test_max_model_len():
     for output in outputs:
         num_total_tokens = len(output.prompt_token_ids) + len(
             output.outputs[0].token_ids)
-        # Total tokens must not exceed max_model_len.
+        # Total tokens must not exceed max_model_len + 1 (the last token can be
+        # generated with the context length equal to the max model length)
         # It can be less if generation finishes due to other reasons (e.g., EOS)
         # before reaching the absolute model length limit.
-        assert num_total_tokens <= max_model_len
+        assert num_total_tokens <= max_model_len + 1
 
 
 def test_log_stats():
diff --git a/tests/v1/e2e/test_context_length.py b/tests/v1/e2e/test_context_length.py
index 67a6c7be4432..b8891d961906 100644
--- a/tests/v1/e2e/test_context_length.py
+++ b/tests/v1/e2e/test_context_length.py
@@ -4,15 +4,22 @@
 end-to-end tests for context length corner cases of vLLM v1 model runner
 versus HuggingFace's transformers.
 
-This test verifies the following behavior: allow a prefill that fills the
-model's maximum context length and then request a single new token.
+This test verifies the following behavior: allow prefill and decodes on the
+model's maximum context length ``max_model_len`` and get one more token.
 
 Test strategy
-- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
-- Run vLLM generation requesting a single new token (max_tokens=1).
-- Run HF generation on the same prompt requesting a single token too.
+- Build a prompt consisting of exactly ``prompt_len`` tokens.
+- Run vLLM generation requesting ``max_tokens`` new tokens.
+- Run HF generation on the same prompt requesting the same number of tokens.
 - Assert both return the same number of generated tokens and the same ids.
 
+Test cases
+- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
+will be sampled after the prefill (context length ``max_model_len``).
+- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
+the 1st will be sampled after the prefill and the 2nd after the first decode
+(context length ``max_model_len``).
+
 """
 
 import pytest
@@ -27,11 +34,16 @@ from vllm.inputs import TokensPrompt
 
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("model", ["JackFram/llama-160m"])
-@pytest.mark.parametrize("max_model_len", [2048])
-@pytest.mark.parametrize("max_tokens", [1])
-def test_prefill_max_context_length(
+@pytest.mark.parametrize(
+    "prompt_len, max_tokens",
+    [
+        (2048, 1),  # prompt_len = max_model_len
+        (2047, 2),  # prompt_len = max_model_len - 1
+    ],
+)
+def test_max_context_length(
     model: str,
-    max_model_len: int,
+    prompt_len: int,
     max_tokens: int,
 ) -> None:
     """Compare vLLM and HuggingFace when the prompt already fills the
@@ -42,8 +54,8 @@
     single token when given the same inputs.
     """
 
-    # Construct a prompt of size max_model_len
-    prompt_ids = [[43] * max_model_len]
+    # Construct a prompt of size prompt_len
+    prompt_ids = [[43] * prompt_len]
 
     # Generate max_tokens new tokens deterministically.
     sampling_params = [
@@ -54,6 +66,7 @@
     llm = LLM(
         model=model,
         tokenizer=model,
+        max_model_len=2048,
         max_num_seqs=1,
         tensor_parallel_size=1,
     )
@@ -81,6 +94,9 @@
     # HF returns the prompt + generated tokens. Slice off the prompt.
     hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
 
+    # check that exactly max_tokens tokens were generated with vLLM and HF
+    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
+
     # check that vLLM outputs (token ids) match HF outputs
     # Note: for simplicity don't pass detokenized string
     check_outputs_equal(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d4be1b06b3b2..6983ccca51f4 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -224,7 +224,7 @@ class Scheduler(SchedulerInterface):
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
-                self.max_model_len - 1 - request.num_computed_tokens)
+                self.max_model_len - request.num_computed_tokens)
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py
index c431843de6ba..6b321f4ebbef 100644
--- a/vllm/v1/core/sched/utils.py
+++ b/vllm/v1/core/sched/utils.py
@@ -43,7 +43,7 @@ def remove_all(lst: list, items_to_remove: set) -> list:
 
 def check_stop(request: Request, max_model_len: int,
                pooler_output: Optional[torch.Tensor] = None) -> bool:
-    if (request.num_tokens >= max_model_len
+    if (request.num_tokens > max_model_len
             or request.num_output_tokens >= request.max_tokens):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
         return True