[Scheduler] Warning upon preemption and Swapping (#4647)

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
SangBin Cho 2024-05-13 23:50:44 +09:00 committed by GitHub
parent 350f9e107f
commit e7c46b9527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 108 additions and 3 deletions

View File

@ -3,6 +3,25 @@
Performance and Tuning
======================
Preemption
----------
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
available again. When this occurs, the following warning is printed:
```
WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1
```
While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency.
If you frequently encounter preemptions from the vLLM engine, consider the following actions:
- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space.
- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space.
- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache.
You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False.
Chunked Prefill
---------------
vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests.

View File

@ -6,6 +6,7 @@ Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
pytest tests/basic_correctness/test_preemption.py`.
"""
import pytest
from prometheus_client import REGISTRY
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
@ -71,6 +72,7 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
@ -87,10 +89,13 @@ def test_preemption(
vllm_model = vllm_runner(
model,
dtype=dtype,
disable_log_stats=False,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
del vllm_model
for i in range(len(example_prompts)):
@ -100,6 +105,20 @@ def test_preemption(
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@ -107,6 +126,7 @@ def test_preemption(
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
@ -122,11 +142,18 @@ def test_swap(
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
vllm_model = vllm_runner(
model,
dtype=dtype,
swap_space=10,
disable_log_stats=False,
)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
del vllm_model
for i in range(len(example_prompts)):
@ -138,6 +165,21 @@ def test_swap(
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.SWAP mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])

View File

@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type):
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
@pytest.fixture()
def temporary_enable_log_propagate():
import logging
logger = logging.getLogger("vllm")
logger.propagate = True
yield
logger.propagate = False
@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog

View File

@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort():
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
assert scheduler.get_num_unfinished_seq_groups() == 2
assert out.preempted == 1
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
scheduler.abort_seq_group("1")

View File

@ -129,6 +129,7 @@ class SchedulerOutputs:
num_lookahead_slots: int
# The number of requests in the running queue
running_queue_size: int
preempted: int
def __post_init__(self):
# Swap in and swap out should never happen at the same time.
@ -310,6 +311,7 @@ class Scheduler:
self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
if self.enable_artificial_preemption
else 0)
self.num_cumulative_preemption: int = 0
@property
def lora_enabled(self) -> bool:
@ -785,6 +787,8 @@ class Scheduler:
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
@ -804,6 +808,7 @@ class Scheduler:
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=preempted,
)
def _schedule_chunked_prefill(self):
@ -891,6 +896,8 @@ class Scheduler:
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
)
def _schedule(self) -> SchedulerOutputs:
@ -1057,6 +1064,17 @@ class Scheduler:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
if self.num_cumulative_preemption % 50 == 0:
logger.warning(
"Sequence group %s is preempted by %s mode because there is "
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d", seq_group.request_id,
preemption_mode, self.num_cumulative_preemption + 1)
self.num_cumulative_preemption += 1
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:

View File

@ -737,6 +737,8 @@ class LLMEngine:
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else
scheduler_outputs.preempted)
# Request stats
# Latency
@ -830,7 +832,6 @@ class LLMEngine:
return Stats(
now=now,
# System stats
# Scheduler State
num_running_sys=num_running_sys,
@ -846,6 +847,7 @@ class LLMEngine:
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter,
# Request stats
# Latency

View File

@ -61,6 +61,10 @@ class Metrics:
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = Counter(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
@ -181,6 +185,7 @@ class Stats:
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
@ -244,6 +249,8 @@ class StatLogger:
stats.cpu_cache_usage_sys)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
@ -336,7 +343,7 @@ class StatLogger:
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%",
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,