[Scheduler] Warning upon preemption and Swapping (#4647)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
parent 350f9e107f, commit e7c46b9527
@@ -3,6 +3,25 @@
 Performance and Tuning
 ======================

+Preemption
+----------
+Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
+The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
+available again. When this occurs, the following warning is printed:
+
+```
+WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1
+```
+
+While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency.
+If you frequently encounter preemptions from the vLLM engine, consider the following actions:
+
+- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space.
+- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space.
+- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache.
+
+You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False.
+
 Chunked Prefill
 ---------------
 vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests.
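For illustration, the knobs named in the documentation above are ordinary engine arguments. A minimal sketch of setting them through the offline `LLM` entry point follows; the model name and the specific values are placeholder assumptions, not recommendations made by this commit.

```python
# Sketch only: wires up the engine arguments discussed in the docs above.
# The model and the numeric values are placeholders for the example.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # placeholder model
    gpu_memory_utilization=0.95,    # larger share of GPU memory -> more KV cache space
    max_num_seqs=64,                # fewer concurrent sequences -> less KV cache pressure
    tensor_parallel_size=1,         # raise on multi-GPU hosts to shard the weights
    swap_space=4,                   # CPU swap space in GiB, used by PreemptionMode.SWAP
    disable_log_stats=False,        # keep stats logging, including cumulative preemptions
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The server entry points expose the same settings as CLI flags of the same names (dashes instead of underscores).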
@@ -6,6 +6,7 @@ Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
 pytest tests/basic_correctness/test_preemption.py`.
 """
 import pytest
+from prometheus_client import REGISTRY

 from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,

@@ -71,6 +72,7 @@ def test_chunked_prefill_recompute(
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [96])
 def test_preemption(
+    caplog_vllm,
     hf_runner,
     vllm_runner,
     example_prompts,

@@ -87,10 +89,13 @@ def test_preemption(
     vllm_model = vllm_runner(
         model,
         dtype=dtype,
+        disable_log_stats=False,
     )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
             ARTIFICIAL_PREEMPTION_MAX_CNT)
+    total_preemption = (
+        vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
     del vllm_model

     for i in range(len(example_prompts)):

@@ -100,6 +105,20 @@ def test_preemption(
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
+            "is not enough KV cache space." in caplog_vllm.text)
+    # Ensure the count bucket of request-level histogram metrics matches
+    # the number of requests as a simple sanity check to ensure metrics are
+    # generated
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "vllm:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption


 @pytest.mark.parametrize("model", MODELS)

@@ -107,6 +126,7 @@ def test_preemption(
 @pytest.mark.parametrize("max_tokens", [96])
 @pytest.mark.parametrize("beam_width", [4])
 def test_swap(
+    caplog_vllm,
     hf_runner,
     vllm_runner,
     example_prompts,

@@ -122,11 +142,18 @@ def test_swap(
                                                    max_tokens)
     del hf_model

-    vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        swap_space=10,
+        disable_log_stats=False,
+    )
     vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                    max_tokens)
     assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
             ARTIFICIAL_PREEMPTION_MAX_CNT)
+    total_preemption = (
+        vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
     del vllm_model

     for i in range(len(example_prompts)):

@@ -138,6 +165,21 @@ def test_swap(
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")

+    assert ("is preempted by PreemptionMode.SWAP mode because there "
+            "is not enough KV cache space." in caplog_vllm.text)
+    # Ensure the count bucket of request-level histogram metrics matches
+    # the number of requests as a simple sanity check to ensure metrics are
+    # generated
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "vllm:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption


 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
@@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type):
                                   pool_type="ray",
                                   extra_config={})
     raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
+
+
+@pytest.fixture()
+def temporary_enable_log_propagate():
+    import logging
+    logger = logging.getLogger("vllm")
+    logger.propagate = True
+    yield
+    logger.propagate = False
+
+
+@pytest.fixture()
+def caplog_vllm(temporary_enable_log_propagate, caplog):
+    # To capture vllm log, we should enable propagate=True temporarily
+    # because caplog depends on logs propagated to the root logger.
+    yield caplog
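A note on why the fixture above is needed: pytest's `caplog` only records messages that propagate to the root logger, and the `vllm` logger keeps `propagate=False` by default, so the fixture flips it on for the duration of a test and restores it afterwards. The sketch below shows the same pattern in isolation; it is all standard `logging`/pytest machinery, and the only vLLM-specific assumption is the `"vllm"` logger name.

```python
# Minimal sketch of the propagate-toggle pattern used by caplog_vllm above.
import logging

import pytest


@pytest.fixture()
def captured_vllm_logs(caplog):
    logger = logging.getLogger("vllm")
    previous = logger.propagate
    logger.propagate = True          # let records reach the root logger so caplog sees them
    try:
        yield caplog
    finally:
        logger.propagate = previous  # restore whatever the library had configured


def test_warning_is_captured(captured_vllm_logs):
    logging.getLogger("vllm").warning("preemption happened")
    assert "preemption happened" in captured_vllm_logs.text
```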
@@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort():
             and not out.blocks_to_swap_out)
     assert len(seq_group_meta) == 1
     assert scheduler.get_num_unfinished_seq_groups() == 2
+    assert out.preempted == 1

     # Abort seq group a. Re-schedule seq group b prompt with recomputation.
     scheduler.abort_seq_group("1")
@@ -129,6 +129,7 @@ class SchedulerOutputs:
     num_lookahead_slots: int
     # The number of requests in the running queue
     running_queue_size: int
+    preempted: int

     def __post_init__(self):
         # Swap in and swap out should never happen at the same time.

@@ -310,6 +311,7 @@ class Scheduler:
         self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                        if self.enable_artificial_preemption
                                        else 0)
+        self.num_cumulative_preemption: int = 0

     @property
     def lora_enabled(self) -> bool:

@@ -785,6 +787,8 @@ class Scheduler:
         # Update swapped requests.
         self.swapped = remaining_swapped
         self.swapped.extend(running_scheduled.swapped_out)
+        preempted = (len(running_scheduled.preempted) +
+                     len(running_scheduled.swapped_out))

         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.

@@ -804,6 +808,7 @@ class Scheduler:
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
+            preempted=preempted,
         )

     def _schedule_chunked_prefill(self):

@@ -891,6 +896,8 @@ class Scheduler:
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
+            preempted=(len(running_scheduled.preempted) +
+                       len(running_scheduled.swapped_out)),
         )

     def _schedule(self) -> SchedulerOutputs:

@@ -1057,6 +1064,17 @@ class Scheduler:
                 preemption_mode = PreemptionMode.RECOMPUTE
             else:
                 preemption_mode = PreemptionMode.SWAP
+
+        if self.num_cumulative_preemption % 50 == 0:
+            logger.warning(
+                "Sequence group %s is preempted by %s mode because there is "
+                "not enough KV cache space. This can affect the end-to-end "
+                "performance. Increase gpu_memory_utilization or "
+                "tensor_parallel_size to provide more KV cache memory. "
+                "total_num_cumulative_preemption=%d", seq_group.request_id,
+                preemption_mode, self.num_cumulative_preemption + 1)
+        self.num_cumulative_preemption += 1
+
         if preemption_mode == PreemptionMode.RECOMPUTE:
             self._preempt_by_recompute(seq_group)
         elif preemption_mode == PreemptionMode.SWAP:
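For a quick check outside the test suite, the counter added above can be read off a live engine the same way the tests do (`llm_engine.scheduler.num_cumulative_preemption`). A rough sketch, assuming the offline `LLM` API as of this commit and a placeholder model; whether any preemption actually occurs depends on the KV cache budget and the load.

```python
# Rough sketch: inspect the scheduler counter added in this commit.
# The access path mirrors the tests above; model and settings are placeholders,
# and the small gpu_memory_utilization is only meant to make preemption likely.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
llm.generate(["Hello"] * 256, SamplingParams(max_tokens=128))

print("cumulative preemptions:",
      llm.llm_engine.scheduler.num_cumulative_preemption)
```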
@@ -737,6 +737,8 @@ class LLMEngine:
         num_generation_tokens_iter = 0
         time_to_first_tokens_iter: List[float] = []
         time_per_output_tokens_iter: List[float] = []
+        num_preemption_iter = (0 if scheduler_outputs is None else
+                               scheduler_outputs.preempted)

         # Request stats
         # Latency

@@ -830,7 +832,6 @@ class LLMEngine:

         return Stats(
             now=now,
-
             # System stats
             # Scheduler State
             num_running_sys=num_running_sys,

@@ -846,6 +847,7 @@ class LLMEngine:
             time_to_first_tokens_iter=time_to_first_tokens_iter,
             time_per_output_tokens_iter=time_per_output_tokens_iter,
             spec_decode_metrics=spec_decode_metrics,
+            num_preemption_iter=num_preemption_iter,

             # Request stats
             # Latency
@@ -61,6 +61,10 @@ class Metrics:
             labelnames=labelnames)

         # Iteration stats
+        self.counter_num_preemption = Counter(
+            name="vllm:num_preemptions_total",
+            documentation="Cumulative number of preemption from the engine.",
+            labelnames=labelnames)
         self.counter_prompt_tokens = Counter(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",

@@ -181,6 +185,7 @@ class Stats:
     num_generation_tokens_iter: int
     time_to_first_tokens_iter: List[float]
     time_per_output_tokens_iter: List[float]
+    num_preemption_iter: int

     # Request stats (should have _requests suffix)
     # Latency

@@ -244,6 +249,8 @@ class StatLogger:
                           stats.cpu_cache_usage_sys)

         # Iteration level data
+        self._log_counter(self.metrics.counter_num_preemption,
+                          stats.num_preemption_iter)
         self._log_counter(self.metrics.counter_prompt_tokens,
                           stats.num_prompt_tokens_iter)
         self._log_counter(self.metrics.counter_generation_tokens,

@@ -336,7 +343,7 @@ class StatLogger:
                 "Avg generation throughput: %.1f tokens/s, "
                 "Running: %d reqs, Swapped: %d reqs, "
                 "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
-                "CPU KV cache usage: %.1f%%",
+                "CPU KV cache usage: %.1f%%.",
                 prompt_throughput,
                 generation_throughput,
                 stats.num_running_sys,
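Since `vllm:num_preemptions_total` is registered alongside the other iteration counters, it shows up wherever the engine's Prometheus metrics are exported, for example the API server's `/metrics` endpoint. A rough sketch of scraping it with `prometheus_client`'s text parser follows; the URL is an assumption, and note that the parser reports the counter family as `vllm:num_preemptions` (without `_total`), which is also why the tests above match on that name.

```python
# Rough sketch: read the new preemption counter from a running vLLM server.
# The URL/port is a placeholder assumption for the example.
import requests
from prometheus_client.parser import text_string_to_metric_families

metrics_text = requests.get("http://localhost:8000/metrics", timeout=5).text

preemptions = 0.0
for family in text_string_to_metric_families(metrics_text):
    if family.name == "vllm:num_preemptions":
        for sample in family.samples:
            if sample.name.endswith("_total"):  # skip the auto-generated *_created sample
                preemptions += sample.value

print(f"cumulative preemptions: {preemptions:.0f}")
```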