[Core] Combined support for multi-step scheduling, chunked prefill & prefix caching (#8804)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>

parent 1570203864
commit 563649aafe
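As a quick orientation, here is a minimal usage sketch of the combination this commit enables: multi-step scheduling together with chunked prefill and automatic prefix caching (APC). The engine-argument names (num_scheduler_steps, enable_chunked_prefill, enable_prefix_caching) appear in the diff below; the LLM entrypoint, model name, and sampling settings are illustrative assumptions, not part of the commit.

# Illustrative sketch only; argument names mirror the engine args used in the diff below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # placeholder model, not from this commit
    num_scheduler_steps=8,          # multi-step: >1 GPU step per GPU->CPU output transfer
    enable_chunked_prefill=True,    # "single-step" chunked prefill
    enable_prefix_caching=True,     # automatic prefix caching (APC)
)
params = SamplingParams(temperature=0.0, max_tokens=5)
outputs = llm.generate(["vLLM is a high-throughput inference engine."], params)
print(outputs[0].outputs[0].text)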
@@ -1,5 +1,6 @@
 # Test the LLMEngine with multi-step-decoding
 
+import copy
 from typing import Optional
 
 import pytest
@@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs", [None, 5])
+def test_multi_step_llm_chunked_prefill_prefix_cache(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: bool,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+) -> None:
+    """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
+
+    Set up a contrived scenario that tests for a possible failure mode of
+    scheduling with multi-step+"single-step chunked prefill"+APC.
+
+    "single-step chunked prefill" here refers to the current vLLM multi-step+
+    chunked-prefill implementation, which requires that a prefill may only
+    be scheduled in the same step as decodes if the prefill prompt fits in a
+    single chunk (note that "complete" multi-step+chunked-prefill would allow
+    a prefill to span multiple chunks & multiple steps, but that is not yet
+    the case.)
+
+    "APC" is short for "automatic prefix caching".
+
+    This test creates a scenario where the scheduler must decide whether/how
+    to schedule a prefill whose prompt exceeds the available token budget.
+    The correct behavior for multi-step+"single-step chunked prefill"+APC is
+    to put off scheduling the prefill until a future step.
+
+    Validate that:
+    * Multi-step kernels do not raise an exception due to incorrect scheduler
+      behavior
+    * Generated tokens match between
+      multi-step+"single-step chunked prefill"+APC and
+      a multi-step baseline without chunked prefill or APC
+    * (If logprobs are enabled) logprobs are close enough
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for baseline and feature-enabled engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager: whether the engine must run in eager mode
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+        GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+        completions endpoint; `None` -> 1 logprob returned.
+    """
+
+    # Set up contrived test for correct scheduling behavior with
+    # multi-step+"single-step chunked prefill"+APC.
+    #
+    # Assume block_size=16
+    #
+    # Assume max_num_batched_tokens=48
+    #   => Per-step token budget=48
+    #
+    # 1. Scheduler schedules 0th prompt (24 tokens)
+    #      => Remaining token budget=24
+    # 2. Scheduler attempts to schedule 1st prompt (30 tokens)
+    #      * 30 tokens exceeds the 24-token remaining budget
+    #      * Correct behavior: do not schedule this prompt in this step
+    #      * Incorrect behavior: schedule a prompt chunk
+    #          * `do_sample=False` for this prompt in this step
+    #          * Chunk size = (remaining tokens // block size) * block size
+    #
+    # The incorrect scheduling behavior - if it occurs - will cause an
+    # exception in the model runner resulting from `do_sample=False`.
+    assert len(example_prompts) >= 2
+    challenge_prompts = copy.deepcopy(example_prompts)
+    challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
+                            'inference and serving engine for LLMs.\n'
+                            )  # 24 tok
+    challenge_prompts[1] = (
+        'Briefly describe the major milestones in the '
+        'development of artificial intelligence from 1950 to 2020.\n'
+    )  # 30 tok
+
+    # If necessary, adjust the length of `challenge_prompts` to match
+    # `num_prompts`
+    if len(challenge_prompts) < num_prompts:
+        challenge_prompts = (challenge_prompts *
+                             ((num_prompts // len(challenge_prompts)) + 1))
+    challenge_prompts = challenge_prompts[:num_prompts]
+    assert len(challenge_prompts) == num_prompts
+
+    # Baseline: multi-step scheduling without chunked prefill or APC
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_baseline = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                            vllm_model.generate_greedy_logprobs(
+                                challenge_prompts, max_tokens, num_logprobs))
+
+    # multi-step+"single-step chunked prefill"+APC
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_w_features = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                              vllm_model.generate_greedy_logprobs(
+                                  challenge_prompts, max_tokens, num_logprobs))
+
+    if num_logprobs is None:
+        # No-logprobs test
+        check_outputs_equal(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
+    else:
+        # Yes-logprobs test
+        check_logprobs_close(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
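To make the contrived scenario in the test comment concrete, here is the same arithmetic as a standalone toy sketch. It models only the budget bookkeeping described in the test; it is not vLLM scheduler code, and the numbers simply mirror the test configuration (block_size=16, max_num_batched_tokens=48, prompts of 24 and 30 tokens).

# Toy walkthrough of the test's token-budget scenario (not vLLM code).
BLOCK_SIZE = 16
TOKEN_BUDGET = 48          # max_num_batched_tokens => per-step token budget

remaining = TOKEN_BUDGET
prompt_lens = [24, 30]     # 0th and 1st challenge prompts

# 1. The 24-token prompt fits within the budget and is scheduled.
remaining -= prompt_lens[0]
assert remaining == 24

# 2. The 30-token prompt exceeds the 24-token remaining budget.
#    Correct behavior: postpone it to a future scheduler step.
#    Incorrect behavior: schedule a partial chunk of
#      (remaining // BLOCK_SIZE) * BLOCK_SIZE = 16 tokens with do_sample=False,
#    which is what triggers the model-runner exception this test guards against.
incorrect_chunk = (remaining // BLOCK_SIZE) * BLOCK_SIZE
assert incorrect_chunk == 16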
@@ -1607,10 +1607,29 @@ class Scheduler:
         # in a decode phase. Do not chunk.
         if enable_chunking and len(seqs) == 1:
             remaining_token_budget = budget.remaining_token_budget()
-            if self.cache_config.enable_prefix_caching:
+            if self.scheduler_config.is_multi_step:
+                # The current multi-step + chunked prefill capability does
+                # not actually support chunking prompts.
+                #
+                # Therefore, `num_new_tokens` is computed in the same fashion
+                # for both multi-step+chunked-prefill &
+                # multi-step+chunked-prefill+APC
+                #
+                # Prompts with more tokens than the current remaining budget
+                # are postponed to future scheduler steps
+                if num_new_tokens > self._get_prompt_limit(seq_group):
+                    # If the seq_group is in prompt-stage, pass the
+                    # num_new_tokens as-is so the caller can ignore
+                    # the sequence.
+                    pass
+                else:
+                    num_new_tokens = 0 \
+                        if num_new_tokens > remaining_token_budget \
+                        else num_new_tokens
+            elif self.cache_config.enable_prefix_caching:
                 # When prefix caching is enabled, we always allocate
-                # the number of new tokens that is dividable by the block size
-                # to avoid partial block matching.
+                # the number of new tokens that is dividable by the block
+                # size to avoid partial block matching.
                 block_size = self.cache_config.block_size
                 remainder = budget.token_budget % block_size
                 if remainder != 0:
@@ -1623,16 +1642,6 @@ class Scheduler:
                 if remaining_token_budget < num_new_tokens:
                     num_new_tokens = (remaining_token_budget //
                                       block_size) * block_size
-            elif self.scheduler_config.is_multi_step:
-                if num_new_tokens > self._get_prompt_limit(seq_group):
-                    # If the seq_group is in prompt-stage, pass the
-                    # num_new_tokens as-is so the caller can ignore
-                    # the sequence.
-                    pass
-                else:
-                    num_new_tokens = 0 \
-                        if num_new_tokens > remaining_token_budget \
-                        else num_new_tokens
             else:
                 num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
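The two hunks above move the multi-step postponement rule ahead of the prefix-caching branch, so it also applies when APC is enabled. Below is a condensed, standalone restatement of the resulting branch structure; the function and parameter names are simplified stand-ins for the Scheduler's config objects and `_get_prompt_limit`, and the ValueError path for a token budget that is not block-aligned is omitted.

# Condensed sketch of the decision logic added above (not the actual method).
def chunked_prefill_tokens(num_new_tokens: int,
                           remaining_token_budget: int,
                           is_multi_step: bool,
                           enable_prefix_caching: bool,
                           prompt_limit: int,
                           block_size: int) -> int:
    if is_multi_step:
        # Multi-step cannot chunk a prompt: either it fits in the remaining
        # budget, or it is postponed (0 tokens) to a future scheduler step.
        # Prompts over the prompt limit are passed through unchanged so the
        # caller can ignore the sequence.
        if num_new_tokens > prompt_limit:
            return num_new_tokens
        return 0 if num_new_tokens > remaining_token_budget else num_new_tokens
    if enable_prefix_caching:
        # Prefix caching: round the allocation down to a whole number of
        # blocks to avoid partial block matching.
        if remaining_token_budget < num_new_tokens:
            return (remaining_token_budget // block_size) * block_size
        return num_new_tokens
    # Plain chunked prefill: clamp to the remaining budget.
    return min(num_new_tokens, remaining_token_budget)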
@@ -999,10 +999,6 @@ class EngineArgs:
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")
-            if self.enable_chunked_prefill and self.enable_prefix_caching:
-                raise ValueError("Multi-Step is not supported with "
-                                 "both Chunked-Prefill and Prefix-Caching "
-                                 "enabled together.")
             if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                 raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                  "for pipeline-parallel-size > 1")
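With the check removed above, multi-step scheduling may now be combined with chunked prefill and prefix caching at the engine-argument level. A minimal sketch of the now-accepted combination follows, assuming the validation shown here runs when the engine config is built (e.g. via `EngineArgs.create_engine_config()`); the model name is a placeholder.

from vllm.engine.arg_utils import EngineArgs

# Before this commit, this combination was rejected with:
#   "Multi-Step is not supported with both Chunked-Prefill and
#    Prefix-Caching enabled together."
args = EngineArgs(
    model="facebook/opt-125m",      # placeholder
    num_scheduler_steps=8,
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
)
engine_config = args.create_engine_config()  # assumed entry point for the check above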