[Core] Reduce TTFT with concurrent partial prefills (#10235)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

parent 5e5c8e091e
commit 3bcb8c75da
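The behavior added here is driven by three scheduler settings introduced in this commit (max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold; see the SchedulerConfig diff below). A minimal sketch of turning the feature on through EngineArgs, modeled on the engine test added in this commit; the model name and token budgets are illustrative values, not prescribed by the change, and running it requires a working vLLM install:

```python
# Sketch only: enabling concurrent partial prefills, based on the EngineArgs
# usage in the new engine test in this diff. Values are illustrative.
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine_args = EngineArgs(
    model="facebook/opt-125m",      # small model used by the tests
    enable_chunked_prefill=True,    # required when max_num_partial_prefills > 1
    max_num_partial_prefills=2,     # up to 2 prompts prefilled per step
    max_num_batched_tokens=64,      # shared per-step token budget
)
engine = LLMEngine.from_engine_args(engine_args)
engine.add_request("0", "hello" * 40, SamplingParams(temperature=0))
outputs = engine.step()  # both waiting prompts can now prefill concurrently
```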
@@ -8,7 +8,6 @@ prefill requests are chunked.
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
 import os
-from contextlib import nullcontext

 import pytest

@@ -233,7 +232,6 @@ def test_with_prefix_caching(

     max_num_batched_tokens = max_num_seqs = chunk_size
     outputs = {}  # type: ignore
-    check_result = True
     for enable in (True, False):
         with vllm_runner(
                 model,
@@ -245,25 +243,17 @@ def test_with_prefix_caching(
                 enforce_eager=enforce_eager,
                 max_num_seqs=max_num_seqs,
         ) as vllm_model:
-            # It should fail when prefix caching is enable and chunk
-            # size is not a multiple of block size (16).
-            should_fail = chunk_size % 16 != 0 and enable
-            check_result &= not should_fail
             outputs[enable] = []
-            # Send the request one-by-one to ensure the cache is populated.
-            with pytest.raises(ValueError) if should_fail else nullcontext():
-                for prompt in full_prompts:
-                    outputs[enable] += vllm_model.generate_greedy([prompt],
-                                                                  max_tokens)
+            for prompt in full_prompts:
+                outputs[enable] += vllm_model.generate_greedy([prompt],
+                                                              max_tokens)

-    # Check results only if we did not expect a failure.
-    if check_result:
-        check_outputs_equal(
-            outputs_0_lst=outputs[False],
-            outputs_1_lst=outputs[True],
-            name_0="w/o prefix caching",
-            name_1="with prefix caching",
-        )
+    check_outputs_equal(
+        outputs_0_lst=outputs[False],
+        outputs_1_lst=outputs[True],
+        name_0="w/o prefix caching",
+        name_1="with prefix caching",
+    )


 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -7,6 +7,9 @@ import pytest  # noqa

 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, SequenceGroup

 from .utils import create_dummy_prompt
@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
     return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


-def append_new_token(seq_group, token_id: int):
+def append_new_token(seq_group: SequenceGroup, token_id: int):
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})

@@ -123,6 +126,232 @@ def test_chunk():
     assert out.num_batched_tokens == 57


+def test_concurrent_chunking():
+    """Verify prefills are chunked properly when
+    --max-num-partial-prefills is > 1"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 32
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Verify both requests are chunked with half of max_num_batched_tokens each
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 32
+    assert seq_group_meta[1].token_chunk_size == 32
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+
+def test_concurrent_chunking_large_requests():
+    """Verify large prefill requests are run one at a time"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+
+    # Verify only a single request is chunked, and it gets all 64 tokens
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 64
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+
+def test_short_prompts_jump_long_prompts_in_queue():
+    """Verify large prefill requests are punted behind smaller ones if
+    another large prefill request is already running"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    long_seqs: List[SequenceGroup] = []
+    short_seqs: List[SequenceGroup] = []
+
+    # Add 2 large seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        long_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Add 2 small seq groups behind them
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i + 2),
+            prompt_length=40,  # Very small prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        short_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Verify one large req and 1 small req chunked
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
+    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens
+
+    # all 4 are prefilling
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # First short and first long sequences have been scheduled
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0
+
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # in the second iteration,
+    # the first small request had only 8 tokens left
+    # so it went to decode
+    # The other small req is scheduled
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # the new small req got 64 - (32+8) tokens
+    assert seq_group_meta[0].token_chunk_size == 24
+    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
+    # the other small request had only 8 tokens left
+    assert seq_group_meta[2].token_chunk_size == 8  # 40-32
+
+    # The first small request got to decode now
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # Both small requests have started in front of the second long request
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24
+
+    assert out.num_prefill_groups == 3
+    assert out.num_batched_tokens == 64
+    # the first small seq group has a new token appended.
+    append_new_token(short_seqs[0], 1)
+
+    # in the third iteration,
+    # the first small request is already decoding
+    # the second small request only has 16 tokens left and will enter decoding
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
+    # small req finished prefilling 40-24=16 tokens
+    assert seq_group_meta[1].token_chunk_size == 16
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 49  # (32+16+1 decode)
+
+    # both small requests have now reached decode
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert not short_seqs[1].is_prefill()
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40
+
+    # both the small seq groups have a new token appended
+    append_new_token(short_seqs[0], 1)
+    append_new_token(short_seqs[1], 1)
+
+    # in the fourth iteration, both small requests are decoding
+    # so large request gets all the budget
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+
+    # large req gets 62 tokens (minus 2 for decode)
+    assert seq_group_meta[0].token_chunk_size == 62
+    assert seq_group_meta[1].token_chunk_size == 1  # decode
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158
+
+    # assert long_seqs[0].is_prefill()
+    # assert long_seqs[1].is_prefill()
+    # assert not short_seqs[0].is_prefill()
+    # assert not short_seqs[1].is_prefill()
+
+    # # both the small seq groups have a new token appended
+    # append_new_token(short_seqs[0], 1)
+    # append_new_token(short_seqs[1], 1)
+
+    # # in the fifth iteration, large request gets all the budget
+    # # while both small requests are decoding
+    # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # assert seq_group_meta[0].token_chunk_size == 62
+    # assert seq_group_meta[1].token_chunk_size == 1  # decode
+    # assert seq_group_meta[2].token_chunk_size == 1  # decode
+    # assert out.num_prefill_groups == 1
+    # assert out.num_batched_tokens == 64
+
+
 def test_complex():
     block_size = 4
     max_seqs = 60
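The chunk sizes asserted above (32 on the first step, 28 on the second) follow from splitting the step's token budget evenly across the running partial prefills. A small sketch of that arithmetic, mirroring the scheduler's precomputed budget lookup further down in this diff; plain Python, not the scheduler code path itself:

```python
# Sketch: the per-prefill chunk sizes in the tests above fall out of an even
# split of max_num_batched_tokens across the running partial prefills.
max_num_batched_tokens = 64
max_num_partial_prefills = 2
prompt_length = 60

budget_per_prefill = max_num_batched_tokens // max_num_partial_prefills
assert budget_per_prefill == 32       # step 1: both 60-token prompts get 32

remaining = prompt_length - budget_per_prefill
assert remaining == 28                # step 2: each prompt prefills its last 28
```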
@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
     assert not running[1].is_prefill()


-def test_perfix_caching():
+def test_prefix_caching():
     """Verify allocating full blocks when prefix caching is enabled."""
     block_size = 4
     max_seqs = 10
@@ -548,3 +777,86 @@ def test_perfix_caching():
     assert seq_group_meta[1].token_chunk_size == 12
     assert out.num_prefill_groups == 2
     assert out.num_batched_tokens == 62
+
+
+def test_prefix_caching_with_concurrent_partial_prefills():
+    """Verify allocating full blocks when prefix caching is enabled with
+    --max-num-partial-prefills > 1."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 8000
+    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       max_num_partial_prefills=2)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # To partially prefill both sequences, both can chunk up to 30 tokens
+    # But the next lowest multiple of the block size (4) is 28
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+    # On the next iteration, both sequences should finish prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # Both sequences have 50 - 28 = 22 tokens left to prefill.
+    # This is not a multiple of the block size, but we don't care since we
+    # don't cache the final partial block of prefix sequences
+    assert seq_group_meta[0].token_chunk_size == 22
+    assert seq_group_meta[1].token_chunk_size == 22
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 44
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
+def test_chunked_prefill_with_actual_engine(model: str,
+                                            max_num_partial_prefills: int):
+    """Make sure the model can actually sample with concurrent
+    partial prefills
+    """
+
+    prompt = "hello" * 40
+
+    engine_args = EngineArgs(
+        model=model,
+        max_num_partial_prefills=max_num_partial_prefills,
+        max_num_batched_tokens=40,
+        max_num_seqs=8,
+        enable_chunked_prefill=True,
+        gpu_memory_utilization=0.8,
+    )
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams(temperature=0)
+
+    for req_num in range(max_num_partial_prefills):
+        engine.add_request(f"{req_num}", prompt, sampling_params)
+    # first step
+    request_outputs = engine.step()
+    # means all are prefilling
+    assert len(request_outputs) == 0
+    assert len(engine.scheduler[0].running) == max_num_partial_prefills
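The 28-token chunks asserted in test_prefix_caching_with_concurrent_partial_prefills come from rounding each slot's share of the budget down to a multiple of the block size, so that cached blocks stay full. A sketch of that rounding with the test's numbers (plain arithmetic, not the block-manager code itself):

```python
# Sketch: block-aligned chunking with prefix caching and two prefill slots.
block_size = 4
max_num_batched_tokens = 60
slots = 2

per_slot = max_num_batched_tokens // slots         # 30 tokens per slot
aligned = (per_slot // block_size) * block_size    # rounded down to 28
assert aligned == 28
# The final chunk (50 - 28 = 22 tokens) need not be block-aligned, since the
# trailing partial block of a prompt is not cached anyway.
```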
@@ -1430,6 +1430,17 @@ class SchedulerConfig:
     # Maximum length of a sequence (including prompt and generated text).
     max_model_len: int = 8192

+    # Maximum number of sequences that can be partially prefilled concurrently
+    max_num_partial_prefills: int = 1
+
+    # Maximum number of "very long prompt" sequences that can be prefilled
+    # concurrently (long is defined by long_prefill_threshold)
+    max_long_partial_prefills: int = 1
+
+    # calculate context length that determines which sequences are
+    # considered "long"
+    long_prefill_token_threshold: int = 0
+
     # The number of slots to allocate per sequence per
     # step, beyond the known token ids. This is used in speculative
     # decoding to store KV activations of tokens which may or may not be
@@ -1537,6 +1548,18 @@ class SchedulerConfig:
                 self.max_num_batched_tokens)

         self.chunked_prefill_enabled = self.enable_chunked_prefill
+        if self.max_num_partial_prefills > 1:
+            if self.long_prefill_token_threshold == 0:
+                self.long_prefill_token_threshold = int(self.max_model_len *
+                                                        0.04)
+
+            logger.info(
+                "Concurrent partial prefills enabled with "
+                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
+                "long_prefill_token_threshold=%d",
+                self.max_num_partial_prefills, self.max_long_partial_prefills,
+                self.long_prefill_token_threshold)
+
         self._verify_args()

     def _verify_args(self) -> None:
@@ -1568,6 +1591,29 @@ class SchedulerConfig:
                 f"({self.num_scheduler_steps}) must be greater than or "
                 "equal to 1.")

+        if self.max_num_partial_prefills < 1:
+            raise ValueError(
+                f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
+                "must be greater than or equal to 1.")
+        elif self.max_num_partial_prefills > 1:
+            if not self.chunked_prefill_enabled:
+                raise ValueError("Chunked prefill must be enabled to set "
+                                 "max_num_partial_prefills > 1.")
+
+            if self.long_prefill_token_threshold > self.max_model_len:
+                raise ValueError(
+                    "long_prefill_token_threshold "
+                    f"({self.long_prefill_token_threshold}) cannot be greater "
+                    f"than the max_model_len ({self.max_model_len}).")
+
+            if (self.max_long_partial_prefills
+                    < 1) or (self.max_long_partial_prefills
+                             > self.max_num_partial_prefills):
+                raise ValueError(
+                    f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
+                    "must be greater than or equal to 1 and less than or equal to "
+                    f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+
     @property
     def is_multi_step(self) -> bool:
         return self.num_scheduler_steps > 1
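When more than one partial prefill is allowed and no threshold is given, the config logic above defaults long_prefill_token_threshold to 4% of max_model_len. A quick worked example of that default; the specific context lengths are illustrative:

```python
# Sketch: the default "long prompt" threshold set by SchedulerConfig above.
def default_long_prefill_token_threshold(max_model_len: int) -> int:
    return int(max_model_len * 0.04)

assert default_long_prefill_token_threshold(2000) == 80    # scheduler tests above
assert default_long_prefill_token_threshold(8192) == 327   # default max_model_len
# With max_model_len=2000, the 1200-token test prompts count as "long";
# the 40- and 60-token prompts do not.
```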
@@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta,
-                           SequenceStatus)
+                           SequenceStage, SequenceStatus)
 from vllm.utils import Device, PyObjectCache

 logger = init_logger(__name__)
@@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum):
     recompute them when the sequences are resumed, treating the sequences as
     new prompts.
     """
+
     SWAP = enum.auto()
     RECOMPUTE = enum.auto()

@@ -54,6 +55,7 @@ class SchedulingBudget:
     happen if we only have chunked prefill scheduling, we can remove this
     feature from the API when chunked prefill is enabled by default.
     """
+
     token_budget: int
     max_num_seqs: int
     _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
@@ -132,6 +134,7 @@ class ScheduledSequenceGroup:
 @dataclass
 class SchedulerOutputs:
     """The scheduling decision made from a scheduler."""
+
     # Scheduled sequence groups.
     scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
     # Number of prefill groups scheduled.
@@ -205,6 +208,7 @@ class SchedulerRunningOutputs:
     Could contain prefill (prefill that's chunked) or decodes. If there's not
     enough memory, it can be preempted (for recompute) or swapped out.
     """
+
     # Selected sequences that are running and in a decoding phase.
     decode_seq_groups: List[ScheduledSequenceGroup]
     # Selected sequences that are running and in a prefill phase.
@@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs:

     Could contain prefill (prefill that's chunked) or decodes.
     """
+
     # Selected sequences that are going to be swapped in and is in a
     # decoding phase.
     decode_seq_groups: List[ScheduledSequenceGroup]
@@ -280,6 +285,7 @@ class SchedulerPrefillOutputs:
     Could contain a fresh prefill requests or preempted requests that need
     to be recomputed from scratch.
     """
+
     # Selected sequences for prefill.
     seq_groups: List[ScheduledSequenceGroup]
     # Ignored sequence groups.
@@ -321,6 +327,100 @@ def scheduled_seq_group_builder():
     # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)


+@dataclass
+class PartialPrefillMetadata:
+    """Holds information about the partial prefills that are currently running
+    during a single iteration of the Scheduler.
+    When chunked prefill is enabled, we allow a certain number of seqs to be
+    partially prefilled during each iteration. Having multiple partial prefills
+    in flight allows us to minimize TTFT and avoid decode starvation in cases
+    where a single sequence group with a very large prompt blocks the queue for
+    too many iterations.
+    The number of long prefill requests is limited so that smaller
+    requests may jump the queue in front of them and get to the decode
+    phase faster.
+    """
+
+    # A minimum bound on the total number of prefills to be scheduled during
+    # this iteration
+    schedulable_prefills: int
+
+    # The number of long prefill requests currently running
+    long_prefills: int
+
+    scheduler_config: SchedulerConfig
+
+    def can_schedule(self, seq_group: SequenceGroup) -> bool:
+        """When concurrent partial prefills are enabled,
+        we limit the number of long requests and only accept
+        shorter requests from the queue while running them
+        concurrently"""
+        return not (seq_group.first_seq.get_num_new_tokens()
+                    > self.scheduler_config.long_prefill_token_threshold
+                    and self.long_prefills
+                    >= self.scheduler_config.max_long_partial_prefills
+                    and self.scheduler_config.max_num_partial_prefills > 1)
+
+    def maybe_increment_partial_prefills(self,
+                                         seq_group: SequenceGroup) -> None:
+        # When a new prefill is scheduled, we need to know if it is a
+        # long request
+        if (seq_group.first_seq.get_num_new_tokens()
+                > self.scheduler_config.long_prefill_token_threshold):
+            self.long_prefills += 1
+
+    @classmethod
+    def from_queues(
+        cls,
+        running: Deque[SequenceGroup],
+        waiting: Deque[SequenceGroup],
+        scheduler_config: SchedulerConfig,
+    ) -> "PartialPrefillMetadata":
+        """Create a PartialPrefillMetadata object from the current state of
+        the scheduler's queues.
+        This accounts for the currently running prefill requests, and peeks
+        into the waiting queue to see if there are more prefills to potentially
+        be scheduled during this iteration."""
+        prefills = 0
+        long_prefills = 0
+
+        waiting_long_prefills = 0
+
+        for sg in running:
+            if sg.first_seq.data.stage == SequenceStage.PREFILL:
+                prefills += 1
+                if (sg.first_seq.get_num_new_tokens()
+                        > scheduler_config.long_prefill_token_threshold):
+                    long_prefills += 1
+
+        for sg in waiting:
+            # Don't bother looping through the rest of the queue if we know
+            # there are already at
+            # least max_partial_prefills requests to fill
+            if prefills >= scheduler_config.max_num_partial_prefills:
+                break
+
+            # Don't count long requests from the waiting queue if we aren't
+            # going to schedule them anyway
+            if (sg.first_seq.get_num_new_tokens()
+                    > scheduler_config.long_prefill_token_threshold):
+                if (long_prefills + waiting_long_prefills
+                        >= scheduler_config.max_long_partial_prefills):
+                    continue
+                waiting_long_prefills += 1
+            prefills += 1
+
+        # NB: long_prefills and waiting_long_prefills are tracked separately.
+        # We don't account for the waiting requests here because we need to use
+        # this metadata to track how many have actually been scheduled.
+        return PartialPrefillMetadata(
+            schedulable_prefills=min(
+                prefills, scheduler_config.max_num_partial_prefills),
+            long_prefills=long_prefills,
+            scheduler_config=scheduler_config,
+        )
+
+
 class Scheduler:

     def __init__(
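PartialPrefillMetadata.can_schedule above only defers a waiting prefill when the request is "long", the long-prefill slots are already occupied, and concurrent partial prefills are actually enabled. A condensed, hypothetical restatement of that rule outside the class (standalone helper, not the real method), checked against the numbers used in the scheduler tests:

```python
# Sketch: the admission rule encoded by PartialPrefillMetadata.can_schedule.
def can_schedule(num_new_tokens: int, long_prefills_running: int,
                 long_prefill_token_threshold: int,
                 max_long_partial_prefills: int,
                 max_num_partial_prefills: int) -> bool:
    is_long = num_new_tokens > long_prefill_token_threshold
    long_slots_full = long_prefills_running >= max_long_partial_prefills
    # Defer only if all three conditions hold.
    return not (is_long and long_slots_full and max_num_partial_prefills > 1)

# With the test settings (threshold 80, one long slot, two partial prefills):
assert can_schedule(1200, 0, 80, 1, 2) is True    # first long prompt admitted
assert can_schedule(1200, 1, 80, 1, 2) is False   # second long prompt waits
assert can_schedule(40, 1, 80, 1, 2) is True      # short prompts still jump in
```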
@@ -360,7 +460,8 @@ class Scheduler:
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window,
-            enable_caching=self.cache_config.enable_prefix_caching)
+            enable_caching=self.cache_config.enable_prefix_caching,
+        )

         # Sequence groups in the WAITING state.
         # Contain new prefill or preempted requests.
@@ -421,6 +522,18 @@ class Scheduler:
         # for processing and deallocation by the free_finished_seq_groups()
         self._async_stopped: List[SequenceGroup] = []

+        # List with the chunk sizes to hand out to each sequence depending
+        # on how many partial prefills are running. This is slightly faster
+        # than running an integer division every time a prefill is scheduled.
+        # This splits the budget evenly among all prefills.
+        self.partial_prefill_budget_lookup_list = [0] * (
+            self.scheduler_config.max_num_partial_prefills + 1)
+        self.partial_prefill_budget_lookup_list[0] = (
+            scheduler_config.max_num_batched_tokens)
+        for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
+            self.partial_prefill_budget_lookup_list[i] = (
+                scheduler_config.max_num_batched_tokens // i)
+
     @property
     def next_cache_id(self):
         return (self.cache_id + 1) % self.num_cache_iters
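The partial_prefill_budget_lookup_list built in __init__ above simply precomputes max_num_batched_tokens // i for each possible number of running partial prefills. A sketch with the values the scheduler tests rely on:

```python
# Sketch: the precomputed per-prefill token budgets from the __init__ above,
# using the test configuration (64 batched tokens, up to 2 partial prefills).
max_num_batched_tokens = 64
max_num_partial_prefills = 2

lookup = [0] * (max_num_partial_prefills + 1)
lookup[0] = max_num_batched_tokens
for i in range(1, max_num_partial_prefills + 1):
    lookup[i] = max_num_batched_tokens // i

assert lookup == [64, 64, 32]  # index = number of running partial prefills
```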
@@ -500,8 +613,8 @@ class Scheduler:
             self.block_manager.free_cross(seq_group)

     def has_unfinished_seqs(self) -> bool:
-        return len(self.waiting) != 0 or len(self.running) != 0 or len(
-            self.swapped) != 0
+        return (len(self.waiting) != 0 or len(self.running) != 0
+                or len(self.swapped) != 0)

     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +636,7 @@ class Scheduler:
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
         enable_chunking: bool = False,
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> SchedulerRunningOutputs:
         """Schedule sequence groups that are running.

@@ -537,12 +651,14 @@ class Scheduler:
                 chunked number of tokens are scheduled if
                 `budget.num_batched_tokens` has not enough capacity to schedule
                 all tokens.
+            partial_prefill_metadata: information about the partial prefills
+                that are currently running

         Returns:
             SchedulerRunningOutputs.
         """
-        ret: SchedulerRunningOutputs = \
-            self._scheduler_running_outputs_cache[self.cache_id].get_object()
+        ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
+            self.cache_id].get_object()
         ret.blocks_to_swap_out.clear()
         ret.blocks_to_copy.clear()
         ret.decode_seq_groups.clear()
@@ -577,10 +693,14 @@ class Scheduler:
             # 2. If a sequence is running with non-chunked prefill, then
             # there it's a decoding sequence, and the cached tokens info is
             # irrelevant.
-            num_uncached_new_tokens, _ = (
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group, SequenceStatus.RUNNING, enable_chunking,
-                    budget))
+            num_uncached_new_tokens, _ = \
+                self._get_num_new_uncached_and_cached_tokens(
+                    seq_group,
+                    SequenceStatus.RUNNING,
+                    enable_chunking,
+                    budget,
+                    partial_prefill_metadata,
+                )

             num_running_tokens = num_uncached_new_tokens
             if num_running_tokens == 0:
@@ -593,8 +713,8 @@ class Scheduler:
             # to process the final tokens. The check below avoids this extra
             # decode run when the model max len is reached, in order to avoid
             # a memory overflow.
-            if self.use_async_output_proc and seq_group.seqs[0].get_len(
-            ) > self.scheduler_config.max_model_len:
+            if (self.use_async_output_proc and seq_group.seqs[0].get_len()
+                    > self.scheduler_config.max_model_len):
                 self._async_stopped.append(seq_group)
                 continue

@@ -653,8 +773,9 @@ class Scheduler:
                 self._append_slots(seq_group, blocks_to_copy, enable_chunking)
                 is_prefill = seq_group.is_prefill()

-                scheduled_seq_group: ScheduledSequenceGroup = \
-                    self._scheduled_seq_group_cache[self.cache_id].get_object()
+                scheduled_seq_group: ScheduledSequenceGroup = (
+                    self._scheduled_seq_group_cache[
+                        self.cache_id].get_object())
                 scheduled_seq_group.seq_group = seq_group
                 if is_prefill:
                     scheduled_seq_group.token_chunk_size = num_running_tokens
@@ -731,7 +852,8 @@ class Scheduler:
                 logger.warning(
                     "Failing the request %s because there's not enough kv "
                     "cache blocks to run the entire sequence.",
-                    seq_group.request_id)
+                    seq_group.request_id,
+                )
                 for seq in seq_group.get_seqs():
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 infeasible_seq_groups.append(seq_group)
@@ -800,16 +922,17 @@ class Scheduler:
         )

     def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
-        if self.scheduler_config.chunked_prefill_enabled and \
-                not self.scheduler_config.is_multi_step:
+        if (self.scheduler_config.chunked_prefill_enabled
+                and not self.scheduler_config.is_multi_step):
             prompt_limit = self.scheduler_config.max_model_len
         else:
-            prompt_limit = min(self.scheduler_config.max_model_len,
-                               self.scheduler_config.max_num_batched_tokens)
+            prompt_limit = min(
+                self.scheduler_config.max_model_len,
+                self.scheduler_config.max_num_batched_tokens,
+            )

         # Model is fine tuned with long context. Return the fine tuned max_len.
-        if (seq_group.lora_request
-                and seq_group.lora_request.long_lora_max_len):
+        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
             assert prompt_limit <= seq_group.lora_request.long_lora_max_len
             return seq_group.lora_request.long_lora_max_len
         else:
@@ -817,7 +940,7 @@ class Scheduler:

     def _get_priority(self,
                       seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
-        """ Get the priority of the sequence group.
+        """Get the priority of the sequence group.
         Highest preference to user-defined priority, followed by arrival time.
         Args:
             seq_group: The sequence group input.
@@ -850,14 +973,14 @@ class Scheduler:
         if waiting_queue:
             seq_group = waiting_queue.popleft()
             num_new_seqs = seq_group.get_max_num_running_seqs()
-            num_new_tokens_uncached, _ = (
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group, SequenceStatus.WAITING, False, budget))
+            num_new_tokens_uncached, _ = \
+                self._get_num_new_uncached_and_cached_tokens(
+                    seq_group, SequenceStatus.WAITING, False, budget)

-            #Only preempt if priority inversion exists
+            # Only preempt if priority inversion exists
             while running_queue and self._get_priority(
                     running_queue[-1]) > self._get_priority(seq_group):
-                #Only preempt if waiting sequence cannot be allocated
+                # Only preempt if waiting sequence cannot be allocated
                 can_allocate = self.block_manager.can_allocate(seq_group)
                 if (num_new_tokens_uncached > 0
                         and can_allocate == AllocStatus.OK
@@ -867,7 +990,7 @@ class Scheduler:
                     )):
                     break

-                #Adjust budget to remove the victim sequence group
+                # Adjust budget to remove the victim sequence group
                 vseq_group = running_queue.pop()
                 num_running_tokens_uncached, _ = (
                     self._get_num_new_uncached_and_cached_tokens(
@@ -878,11 +1001,11 @@ class Scheduler:
                 budget.subtract_num_seqs(vseq_group.request_id,
                                          num_running_seqs)

-                #Preempt out the victim sequence group
+                # Preempt out the victim sequence group
                 self._preempt(vseq_group, blocks_to_swap_out)
                 waiting_queue.appendleft(vseq_group)
                 force_preemption_count += 1
-            #Put the sequence back into the waiting queue
+            # Put the sequence back into the waiting queue
             waiting_queue.appendleft(seq_group)

         waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
@@ -896,6 +1019,7 @@ class Scheduler:
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
         enable_chunking: bool = False,
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> SchedulerPrefillOutputs:
         """Schedule sequence groups that are in prefill stage.

@@ -916,10 +1040,20 @@ class Scheduler:
                 chunked number of tokens are scheduled if
                 `budget.num_batched_tokens` has not enough capacity to schedule
                 all tokens.
+            partial_prefill_metadata: information about the partial prefills
+                that are currently running

         Returns:
             SchedulerPrefillOutputs.
         """
+        if budget.remaining_token_budget() == 0:
+            # Do nothing: Can't add any more prefill anyway
+            return SchedulerPrefillOutputs(
+                seq_groups=[],
+                ignored_seq_groups=[],
+                num_lookahead_slots=self._get_num_lookahead_slots(
+                    is_prefill=True, enable_chunking=enable_chunking),
+            )
         ignored_seq_groups: List[SequenceGroup] = []
         seq_groups: List[ScheduledSequenceGroup] = []

@@ -933,10 +1067,19 @@ class Scheduler:
             assert len(waiting_seqs) == 1, (
                 "Waiting sequence group should have only one prompt "
                 "sequence.")
+            if (partial_prefill_metadata is not None
+                    and not partial_prefill_metadata.can_schedule(seq_group)):
+                leftover_waiting_sequences.appendleft(seq_group)
+                waiting_queue.popleft()
+                continue
             num_new_tokens_uncached, num_new_tokens_cached = (
                 self._get_num_new_uncached_and_cached_tokens(
-                    seq_group, SequenceStatus.WAITING, enable_chunking,
-                    budget))
+                    seq_group,
+                    SequenceStatus.WAITING,
+                    enable_chunking,
+                    budget,
+                    partial_prefill_metadata=partial_prefill_metadata,
+                ))
             num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

             if not enable_chunking:
@@ -947,7 +1090,10 @@ class Scheduler:
                 if num_new_tokens > prompt_limit:
                     logger.warning(
                         "Input prompt (%d tokens) is too long"
-                        " and exceeds limit of %d", num_new_tokens, prompt_limit)
+                        " and exceeds limit of %d",
+                        num_new_tokens,
+                        prompt_limit,
+                    )
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
@@ -968,7 +1114,9 @@ class Scheduler:
                 logger.warning(
                     "Input prompt (%d tokens) + lookahead slots (%d) is "
                     "too long and exceeds the capacity of block_manager",
-                    num_new_tokens, num_lookahead_slots)
+                    num_new_tokens,
+                    num_lookahead_slots,
+                )
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
@@ -1009,6 +1157,10 @@ class Scheduler:
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)

+            if partial_prefill_metadata is not None:
+                partial_prefill_metadata.maybe_increment_partial_prefills(
+                    seq_group)
+
             if enable_chunking and self.scheduler_config.is_multi_step:
                 blocks_to_copy: List[Tuple[int, int]] = []
                 # init_multi_step_from_lookahead_slots happens in append_slots
@@ -1024,7 +1176,8 @@ class Scheduler:
                     num_scheduler_steps=self.scheduler_config.
                     num_scheduler_steps,
                     is_multi_step=self.scheduler_config.is_multi_step,
-                    enable_chunking=enable_chunking)
+                    enable_chunking=enable_chunking,
+                )

             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
@@ -1045,11 +1198,12 @@ class Scheduler:
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,
             num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=True, enable_chunking=enable_chunking))
+                is_prefill=True, enable_chunking=enable_chunking),
+        )

     def _schedule_default(self) -> SchedulerOutputs:
         """Schedule queued requests.

         The current policy is designed to optimize the throughput. First,
         it batches as many prefill requests as possible. And it schedules
         decodes. If there's a pressure on GPU memory, decode requests can
@@ -1065,9 +1219,9 @@ class Scheduler:
         for seq_group in self.running:
             budget.add_num_seqs(seq_group.request_id,
                                 seq_group.get_max_num_running_seqs())
-        curr_loras = set(
+        curr_loras = (set(
             seq_group.lora_int_id for seq_group in self.running
-            if seq_group.lora_int_id > 0) if self.lora_enabled else None
+            if seq_group.lora_int_id > 0) if self.lora_enabled else None)

         prefills = SchedulerPrefillOutputs.create_empty()
         running_scheduled = SchedulerRunningOutputs.create_empty()
@@ -1093,9 +1247,10 @@ class Scheduler:

         # If any sequence group is preempted, do not swap in any sequence
         # group. because it means there's no slot for new running requests.
-        if len(running_scheduled.preempted) + len(
-                running_scheduled.swapped_out) == 0:
-            swapped_in = self._schedule_swapped(budget, curr_loras)
+        if (len(running_scheduled.preempted) +
+                len(running_scheduled.swapped_out) == 0):
+            swapped_in = \
+                self._schedule_swapped(budget, curr_loras)

         assert (budget.num_batched_tokens
                 <= self.scheduler_config.max_num_batched_tokens)
@@ -1115,8 +1270,8 @@ class Scheduler:

         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
-        preempted = (len(running_scheduled.preempted) +
-                     len(running_scheduled.swapped_out))
+        preempted = len(running_scheduled.preempted) + len(
+            running_scheduled.swapped_out)

         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.
@@ -1154,7 +1309,7 @@ class Scheduler:

     def _schedule_chunked_prefill(self) -> SchedulerOutputs:
         """Schedule queued requests.

         Chunked prefill allows to chunk prefill requests, batch them together
         with decode requests. This policy 1. schedule as many decoding requests
         as possible. 2. schedule chunked prefill requests that are not
@@ -1175,10 +1330,20 @@ class Scheduler:
         prefills = SchedulerPrefillOutputs.create_empty()
         swapped_in = SchedulerSwappedInOutputs.create_empty()

+        # Create partial prefill metadata
+        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
+            running=self.running,
+            waiting=self.waiting,
+            scheduler_config=self.scheduler_config,
+        )
+
         # Decoding should be always scheduled first by fcfs.
-        running_scheduled = self._schedule_running(budget,
-                                                   curr_loras,
-                                                   enable_chunking=True)
+        running_scheduled = self._schedule_running(
+            budget,
+            curr_loras,
+            enable_chunking=True,
+            partial_prefill_metadata=partial_prefill_metadata,
+        )

         # Schedule swapped out requests.
         # If preemption happens, it means we don't have space for swap-in.
@@ -1186,9 +1351,12 @@ class Scheduler:
                 running_scheduled.swapped_out) == 0:
             swapped_in = self._schedule_swapped(budget, curr_loras)

-        prefills = self._schedule_prefills(budget,
-                                           curr_loras,
-                                           enable_chunking=True)
+        prefills = self._schedule_prefills(
+            budget,
+            curr_loras,
+            enable_chunking=True,
+            partial_prefill_metadata=partial_prefill_metadata,
+        )

         assert (budget.num_batched_tokens
                 <= self.scheduler_config.max_num_batched_tokens)
@@ -1207,8 +1375,15 @@ class Scheduler:
             [s.seq_group for s in swapped_in.prefill_seq_groups])
         self.running.extend(
             [s.seq_group for s in running_scheduled.decode_seq_groups])
+        # Because multiple prefills may be running concurrently, we need to
+        # make sure that prefills which are scheduled to finish are listed
+        # before those that won't. This is so that on the next scheduling
+        # iteration when they have transitioned to the decode stage, they are
+        # properly prioritized over sequences that are still in the prefill
+        # stage.
         self.running.extend(
-            [s.seq_group for s in running_scheduled.prefill_seq_groups])
+            self._order_finishing_prefills_first(
+                running_scheduled.prefill_seq_groups))
         self.running.extend([s.seq_group for s in prefills.seq_groups])

         # Update swapped requests.
@@ -1225,7 +1400,7 @@ class Scheduler:
         # If all prompts, then we set num_lookahead_slots to 0
         # this allows us to go through the `no_spec` path in
         # `spec_decode_worker.py`
-        all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
+        all_prefills = len(scheduled_seq_groups) == num_prefill_groups
         num_lookahead_slots = (0 if
                                (all_prefills
                                 and not self.scheduler_config.is_multi_step)
@ -1247,6 +1422,21 @@ class Scheduler:
|
|||||||
len(running_scheduled.swapped_out)),
|
len(running_scheduled.swapped_out)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _order_finishing_prefills_first(
|
||||||
|
self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
|
||||||
|
) -> List[SequenceGroup]:
|
||||||
|
"""Returns a list of prefilling SequenceGroups where sequences that are
|
||||||
|
scheduled to finish prefilling are listed first"""
|
||||||
|
finishing = [
|
||||||
|
s.seq_group for s in scheduled_prefill_seqs
|
||||||
|
if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
|
||||||
|
]
|
||||||
|
not_finishing = [
|
||||||
|
s.seq_group for s in scheduled_prefill_seqs
|
||||||
|
if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
|
||||||
|
]
|
||||||
|
return finishing + not_finishing
|
||||||
|
|
||||||
def _schedule(self) -> SchedulerOutputs:
|
def _schedule(self) -> SchedulerOutputs:
|
||||||
"""Schedule queued requests."""
|
"""Schedule queued requests."""
|
||||||
if self.scheduler_config.chunked_prefill_enabled:
|
if self.scheduler_config.chunked_prefill_enabled:
|
||||||
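The new `_order_finishing_prefills_first` helper simply partitions the scheduled prefills by whether this step's chunk covers all of their remaining prompt tokens. A self-contained toy version of that partitioning, using stand-in classes (the names here are illustrative, not vLLM types):

```python
from dataclasses import dataclass


@dataclass
class ToySeqGroup:
    name: str
    uncomputed_tokens: int

    def get_num_uncomputed_tokens(self) -> int:
        return self.uncomputed_tokens


@dataclass
class ToyScheduled:
    seq_group: ToySeqGroup
    token_chunk_size: int


def order_finishing_prefills_first(scheduled):
    # A chunk that covers all remaining uncomputed tokens means the request
    # finishes prefill this step, so it goes to the front of the running list.
    finishing = [
        s.seq_group for s in scheduled
        if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
    ]
    not_finishing = [
        s.seq_group for s in scheduled
        if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
    ]
    return finishing + not_finishing


scheduled = [
    ToyScheduled(ToySeqGroup("long-prompt", 4096), token_chunk_size=512),
    ToyScheduled(ToySeqGroup("short-prompt", 128), token_chunk_size=128),
]
print([g.name for g in order_finishing_prefills_first(scheduled)])
# ['short-prompt', 'long-prompt']
```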
@@ -1385,10 +1575,12 @@ class Scheduler:
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
                 # `multi_modal_data` will be None.
-                multi_modal_data=seq_group.multi_modal_data
-                if scheduler_outputs.num_prefill_groups > 0 else None,
-                multi_modal_placeholders=seq_group.multi_modal_placeholders
-                if scheduler_outputs.num_prefill_groups > 0 else None,
+                multi_modal_data=(seq_group.multi_modal_data
+                                  if scheduler_outputs.num_prefill_groups
+                                  > 0 else None),
+                multi_modal_placeholders=(
+                    seq_group.multi_modal_placeholders
+                    if scheduler_outputs.num_prefill_groups > 0 else None),
                 mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
@@ -1494,10 +1686,12 @@ class Scheduler:
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             seq.status = SequenceStatus.RUNNING
 
-    def _append_slots(self,
-                      seq_group: SequenceGroup,
-                      blocks_to_copy: List[Tuple[int, int]],
-                      enable_chunking: bool = False) -> None:
+    def _append_slots(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_copy: List[Tuple[int, int]],
+        enable_chunking: bool = False,
+    ) -> None:
         """Appends new slots to the sequences in the given sequence group.
 
         Args:

@@ -1518,7 +1712,8 @@ class Scheduler:
             num_lookahead_slots,
             num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
             is_multi_step=self.scheduler_config.is_multi_step,
-            enable_chunking=enable_chunking)
+            enable_chunking=enable_chunking,
+        )
 
         seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
         if self.scheduler_config.is_multi_step and enable_chunking:
@@ -1561,8 +1756,11 @@ class Scheduler:
             "not enough KV cache space. This can affect the end-to-end "
             "performance. Increase gpu_memory_utilization or "
             "tensor_parallel_size to provide more KV cache memory. "
-            "total_num_cumulative_preemption=%d", seq_group.request_id,
-            preemption_mode, self.num_cumulative_preemption + 1)
+            "total_num_cumulative_preemption=%d",
+            seq_group.request_id,
+            preemption_mode,
+            self.num_cumulative_preemption + 1,
+        )
         self.num_cumulative_preemption += 1
 
         if preemption_mode == PreemptionMode.RECOMPUTE:
@@ -1668,6 +1866,7 @@ class Scheduler:
         status: SequenceStatus,
         enable_chunking: bool,
         budget: SchedulingBudget,
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> Tuple[int, int]:
         """
         Returns the number of new uncached and cached tokens to schedule for a

@@ -1691,6 +1890,8 @@ class Scheduler:
                 to schedule.
             enable_chunking: Whether to chunk the number of tokens to compute.
             budget: The budget to chunk the number of tokens to compute.
+            partial_prefill_metadata: information about the partial prefills
+                that are currently running
 
         Returns:
@@ -1768,6 +1969,8 @@ class Scheduler:
                 budget,
                 self._get_prompt_limit(seq_group),
                 num_uncached_new_tokens,
+                self.partial_prefill_budget_lookup_list,
+                partial_prefill_metadata,
             )
 
         return num_uncached_new_tokens, num_cached_new_tokens
@@ -1779,6 +1982,8 @@ class Scheduler:
         budget: SchedulingBudget,
         prompt_limit: int,
         num_new_tokens: int,
+        partial_prefill_budget_lookup_list: List[int],
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> int:
         """
         Chunks the number of new tokens to schedule based on the budget when

@@ -1811,29 +2016,31 @@ class Scheduler:
                 # the sequence.
                 return num_new_tokens
 
-            return (0 if num_new_tokens > remaining_token_budget else
-                    num_new_tokens)
+            return 0 if num_new_tokens > \
+                remaining_token_budget else num_new_tokens
+
+        # Get the number of tokens to allocate to this prefill slot
+        prefill_slot_budget = (
+            remaining_token_budget if partial_prefill_metadata is None else
+            partial_prefill_budget_lookup_list[
+                partial_prefill_metadata.schedulable_prefills])
 
         if cache_config.enable_prefix_caching:
-            # Adjust the remaining token budget to be divisible by the block
-            # size when prefix caching is enabled.
-            # When prefix caching is enabled, we always allocate
-            # the number of new tokens that is dividable by the block
-            # size to avoid partial block matching.
+            # When prefix caching is enabled and we're partially prefilling
+            # a sequence, we always allocate a number of new tokens that is
+            # divisible by the block size to avoid partial block matching.
             block_size = cache_config.block_size
-            remainder = budget.token_budget % block_size
-            if remainder != 0:
-                raise ValueError("When enabling chunked prefill and "
-                                 "prefix caching, max_num_batched_tokens "
-                                 "(chunk size) must be dividable by "
-                                 "block size, but got chunk_size "
-                                 f"({budget.token_budget}) % block_size "
-                                 f"({block_size}) = {remainder}")
-            # Round down to block size.
-            remaining_token_budget = (remaining_token_budget // block_size *
-                                      block_size)
+            # Don't exceed either the total budget or slot budget.
+            # Take min of those and get the next lowest multiple of the
+            # block size:
+            remaining_token_budget = (
+                min(remaining_token_budget, prefill_slot_budget) //
+                block_size) * block_size
+            # NB: In the case where num_new_tokens < budget, we are
+            # finishing prefill for this sequence, so we do not need to
+            # allocate a full block.
 
-        num_new_tokens = min(num_new_tokens, remaining_token_budget)
+        num_new_tokens = min(num_new_tokens, remaining_token_budget,
+                             prefill_slot_budget)
 
         return num_new_tokens
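A worked example of the chunking math above, assuming (as an illustration only) that the lookup table divides `max_num_batched_tokens` roughly evenly across the schedulable prefill slots; the concrete numbers below are made up:

```python
# Worked example of the budget rounding above (values are made up):
# with a 2048-token batch budget split across 4 concurrent prefill slots,
# each slot gets 512 tokens; rounding down to a 16-token block boundary
# keeps prefix-cache block matching aligned.
block_size = 16
remaining_token_budget = 2000      # tokens still unclaimed in this batch
prefill_slot_budget = 2048 // 4    # per-slot share with 4 concurrent prefills

remaining_token_budget = (min(remaining_token_budget, prefill_slot_budget) //
                          block_size) * block_size
print(remaining_token_budget)      # 512

num_new_tokens = min(700, remaining_token_budget, prefill_slot_budget)
print(num_new_tokens)              # 512: a long prompt is chunked to the slot budget
```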
@@ -120,6 +120,9 @@ class EngineArgs:
     cpu_offload_gb: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
+    max_num_partial_prefills: Optional[int] = 1
+    max_long_partial_prefills: Optional[int] = 1
+    long_prefill_token_threshold: Optional[int] = 0
     max_num_seqs: Optional[int] = None
     max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
     disable_log_stats: bool = False

@@ -515,6 +518,31 @@ class EngineArgs:
                             default=EngineArgs.max_num_batched_tokens,
                             help='Maximum number of batched tokens per '
                             'iteration.')
+        parser.add_argument(
+            "--max-num-partial-prefills",
+            type=int,
+            default=EngineArgs.max_num_partial_prefills,
+            help="For chunked prefill, the max number of concurrent \
+            partial prefills."
+            "Defaults to 1",
+        )
+        parser.add_argument(
+            "--max-long-partial-prefills",
+            type=int,
+            default=EngineArgs.max_long_partial_prefills,
+            help="For chunked prefill, the maximum number of prompts longer "
+            "than --long-prefill-token-threshold that will be prefilled "
+            "concurrently. Setting this less than --max-num-partial-prefills "
+            "will allow shorter prompts to jump the queue in front of longer "
+            "prompts in some cases, improving latency. Defaults to 1.")
+        parser.add_argument(
+            "--long-prefill-token-threshold",
+            type=float,
+            default=EngineArgs.long_prefill_token_threshold,
+            help="For chunked prefill, a request is considered long if the "
+            "prompt is longer than this number of tokens. Defaults to 4%% of "
+            "the model's context length.",
+        )
         parser.add_argument('--max-num-seqs',
                             type=int,
                             default=EngineArgs.max_num_seqs,
@@ -1244,7 +1272,11 @@ class EngineArgs:
             multi_step_stream_outputs=self.multi_step_stream_outputs,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
-            policy=self.scheduling_policy)
+            policy=self.scheduling_policy,
+            max_num_partial_prefills=self.max_num_partial_prefills,
+            max_long_partial_prefills=self.max_long_partial_prefills,
+            long_prefill_token_threshold=self.long_prefill_token_threshold,
+        )
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
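A hedged usage sketch of the knobs wired through above, via `EngineArgs`; the model name and values are placeholders, and this assumes chunked prefill is enabled so the partial-prefill limits take effect:

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",         # placeholder model
    enable_chunked_prefill=True,       # partial prefills build on chunked prefill
    max_num_partial_prefills=4,        # up to 4 prompts prefilled concurrently
    max_long_partial_prefills=2,       # at most 2 of them may be "long" prompts
    long_prefill_token_threshold=512,  # prompts above this count as long
)
```

Keeping `max_long_partial_prefills` below `max_num_partial_prefills` reserves some of the concurrent slots for short prompts, which is the latency benefit described in the `--max-long-partial-prefills` help text.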
@@ -958,7 +958,9 @@ def get_logprobs(
     if len(query_indices) == 0:
         empty_sampled_logprob: SampleLogprobs = []
         empty_prompt_logprob: Optional[PromptLogprobs] = None
-        return [empty_prompt_logprob], [empty_sampled_logprob]
+        num_seq_groups = len(sampling_metadata.seq_groups)
+        return [empty_prompt_logprob
+                ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
 
     selected_logprobs, ranks = None, None
     top_logprobs, top_token_ids = None, None
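The `get_logprobs` change above matters once several sequence groups can be mid-prefill in the same step: the early-return lists must stay one-to-one with `sampling_metadata.seq_groups`, which the new assert in `_build_sampler_output` below also enforces. A minimal sketch of that invariant with toy values (not vLLM's actual types):

```python
# Three sequence groups, e.g. three concurrent partial prefills, and no
# tokens sampled this step: the per-group lists must still line up.
num_seq_groups = 3
empty_prompt_logprob = None
empty_sampled_logprob: list = []

prompt_logprobs = [empty_prompt_logprob] * num_seq_groups
sample_logprobs = [empty_sampled_logprob] * num_seq_groups

assert len(prompt_logprobs) == len(sample_logprobs) == num_seq_groups
```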
@@ -1225,6 +1227,10 @@ def _build_sampler_output(
         assert sample_logprobs is not None
         assert not isinstance(maybe_deferred_sample_results,
                               SampleResultArgsType)
+        assert len(sampling_metadata.seq_groups) \
+            == len(maybe_deferred_sample_results) \
+            == len(prompt_logprobs) \
+            == len(sample_logprobs)
         deferred_sample_results_args = None
 
     for (seq_group, sample_result, group_prompt_logprobs,