diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst
index 589fce21056c..d8750ddc34e8 100644
--- a/docs/source/models/performance.rst
+++ b/docs/source/models/performance.rst
@@ -3,6 +3,25 @@
 Performance and Tuning
 ======================
 
+Preemption
+----------
+Due to the auto-regressive nature of the transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
+In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
+available again. When this occurs, the following warning is printed:
+
+.. code-block:: text
+
+    WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=1
+
+While this mechanism keeps the system robust, preemption and recomputation can adversely affect end-to-end latency.
+If you frequently encounter preemptions from the vLLM engine, consider the following actions:
+
+- Increase ``gpu_memory_utilization``. vLLM pre-allocates the GPU cache using ``gpu_memory_utilization``% of GPU memory, so a higher value provides more KV cache space.
+- Decrease ``max_num_seqs`` or ``max_num_batched_tokens``. This reduces the number of concurrent requests in a batch and therefore the KV cache space they require.
+- Increase ``tensor_parallel_size``. This shards the model weights across GPUs, leaving more memory on each GPU for the KV cache.
+
+You can also monitor the number of preemptions through the Prometheus metrics exposed by vLLM, and log the cumulative number of preemptions by setting ``disable_log_stats=False``.
+
 Chunked Prefill
 ---------------
 vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests.
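The tuning options listed in the documentation above are all engine arguments, so they can be exercised directly from the Python entrypoint. The following sketch is illustrative only and not part of this patch; the model name and the specific values are placeholders, not recommendations:

```python
# Illustrative sketch only (not part of this patch): placeholder model and values.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",     # placeholder model
    gpu_memory_utilization=0.95,   # give the KV cache a larger share of GPU memory
    tensor_parallel_size=2,        # shard weights across two GPUs (requires 2 GPUs)
    max_num_seqs=128,              # cap the number of concurrently batched sequences
    disable_log_stats=False,       # keep periodic stats, including preemption counts
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
```

With ``disable_log_stats=False``, both the periodic stats log and the ``vllm:num_preemptions_total`` counter added later in this patch report preemptions.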
diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py
index ffb0717b3bfd..29a4c39cd25a 100644
--- a/tests/basic_correctness/test_preemption.py
+++ b/tests/basic_correctness/test_preemption.py
@@ -6,6 +6,7 @@ Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
 pytest tests/basic_correctness/test_preemption.py`.
 """
 import pytest
+from prometheus_client import REGISTRY
 
 from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
@@ -71,6 +72,7 @@ def test_chunked_prefill_recompute(
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [96])
 def test_preemption(
+    caplog_vllm,
     hf_runner,
     vllm_runner,
     example_prompts,
@@ -87,10 +89,13 @@ def test_preemption(
     vllm_model = vllm_runner(
         model,
         dtype=dtype,
+        disable_log_stats=False,
     )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
             ARTIFICIAL_PREEMPTION_MAX_CNT)
+    total_preemption = (
+        vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
     del vllm_model
 
     for i in range(len(example_prompts)):
@@ -100,6 +105,20 @@ def test_preemption(
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
+            "is not enough KV cache space." in caplog_vllm.text)
+    # Ensure the preemption counter exposed through Prometheus matches the
+    # scheduler's cumulative preemption count, as a simple sanity check
+    # that the metric is actually emitted.
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "vllm:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -107,6 +126,7 @@ def test_preemption(
 @pytest.mark.parametrize("max_tokens", [96])
 @pytest.mark.parametrize("beam_width", [4])
 def test_swap(
+    caplog_vllm,
     hf_runner,
     vllm_runner,
     example_prompts,
@@ -122,11 +142,18 @@ def test_swap(
                                                max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        swap_space=10,
+        disable_log_stats=False,
+    )
     vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                    max_tokens)
     assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
             ARTIFICIAL_PREEMPTION_MAX_CNT)
+    total_preemption = (
+        vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
     del vllm_model
 
     for i in range(len(example_prompts)):
@@ -138,6 +165,21 @@ def test_swap(
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
 
+    assert ("is preempted by PreemptionMode.SWAP mode because there "
+            "is not enough KV cache space." in caplog_vllm.text)
+    # Ensure the preemption counter exposed through Prometheus matches the
+    # scheduler's cumulative preemption count, as a simple sanity check
+    # that the metric is actually emitted.
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "vllm:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption
+
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
diff --git a/tests/conftest.py b/tests/conftest.py
index b8117a19c75d..999ace2c3c69 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type):
                                    pool_type="ray",
                                    extra_config={})
     raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
+
+
+@pytest.fixture()
+def temporary_enable_log_propagate():
+    import logging
+    logger = logging.getLogger("vllm")
+    logger.propagate = True
+    yield
+    logger.propagate = False
+
+
+@pytest.fixture()
+def caplog_vllm(temporary_enable_log_propagate, caplog):
+    # To capture vLLM logs, temporarily enable propagate=True, because caplog
+    # only sees records that propagate to the root logger.
+    yield caplog
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 6bcabc4f95fa..07fc8731e184 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort():
             and not out.blocks_to_swap_out)
     assert len(seq_group_meta) == 1
     assert scheduler.get_num_unfinished_seq_groups() == 2
+    assert out.preempted == 1
 
     # Abort seq group a. Re-schedule seq group b prompt with recomputation.
     scheduler.abort_seq_group("1")
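Both tests above verify the new counter by reading it back from prometheus_client's default registry. A standalone sketch of that lookup pattern is shown below; the helper name is illustrative and is not part of vLLM or of this patch:

```python
# Standalone sketch of the registry lookup used in the tests above.
# `metric_total` is an illustrative helper, not a vLLM or pytest API.
from prometheus_client import REGISTRY


def metric_total(name: str) -> float:
    """Sum every sample of the named metric family in the default registry."""
    for family in REGISTRY.collect():
        if family.name == name:
            # A labelled counter yields one sample per label set; sum them all.
            return sum(sample.value for sample in family.samples)
    raise KeyError(f"metric {name!r} not found")


# e.g. metric_total("vllm:num_preemptions") after the engine has processed requests
```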
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index fb6e985b2f31..fbde27f99823 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -129,6 +129,7 @@ class SchedulerOutputs:
     num_lookahead_slots: int
     # The number of requests in the running queue
     running_queue_size: int
+    preempted: int
 
     def __post_init__(self):
         # Swap in and swap out should never happen at the same time.
@@ -310,6 +311,7 @@ class Scheduler:
         self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                        if self.enable_artificial_preemption
                                        else 0)
+        self.num_cumulative_preemption: int = 0
 
     @property
     def lora_enabled(self) -> bool:
@@ -785,6 +787,8 @@ class Scheduler:
         # Update swapped requests.
         self.swapped = remaining_swapped
         self.swapped.extend(running_scheduled.swapped_out)
+        preempted = (len(running_scheduled.preempted) +
+                     len(running_scheduled.swapped_out))
 
         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.
@@ -804,6 +808,7 @@ class Scheduler:
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
+            preempted=preempted,
         )
 
     def _schedule_chunked_prefill(self):
@@ -891,6 +896,8 @@ class Scheduler:
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
+            preempted=(len(running_scheduled.preempted) +
+                       len(running_scheduled.swapped_out)),
         )
 
     def _schedule(self) -> SchedulerOutputs:
@@ -1057,6 +1064,17 @@ class Scheduler:
             preemption_mode = PreemptionMode.RECOMPUTE
         else:
             preemption_mode = PreemptionMode.SWAP
+
+        if self.num_cumulative_preemption % 50 == 0:
+            logger.warning(
+                "Sequence group %s is preempted by %s mode because there is "
+                "not enough KV cache space. This can affect the end-to-end "
+                "performance. Increase gpu_memory_utilization or "
+                "tensor_parallel_size to provide more KV cache memory. "
+                "total_num_cumulative_preemption=%d", seq_group.request_id,
+                preemption_mode, self.num_cumulative_preemption + 1)
+        self.num_cumulative_preemption += 1
+
         if preemption_mode == PreemptionMode.RECOMPUTE:
             self._preempt_by_recompute(seq_group)
         elif preemption_mode == PreemptionMode.SWAP:
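Note that the warning added above is rate-limited: because the counter is incremented after the check, it fires on the 1st, 51st, 101st, ... preemption rather than on every one. Below is a minimal standalone sketch of that counting pattern; the class and names are illustrative, not vLLM APIs:

```python
# Standalone sketch of the "warn every Nth occurrence" pattern used above.
# PreemptionTracker is an illustrative name, not a vLLM class.
import logging

logger = logging.getLogger(__name__)


class PreemptionTracker:

    def __init__(self, warn_interval: int = 50) -> None:
        self.warn_interval = warn_interval
        self.num_cumulative_preemption = 0

    def record(self, request_id: str) -> None:
        # Incrementing after the check means the warning fires on the
        # 1st, 51st, 101st, ... occurrence, keeping the log readable.
        if self.num_cumulative_preemption % self.warn_interval == 0:
            logger.warning(
                "Request %s preempted; total_num_cumulative_preemption=%d",
                request_id, self.num_cumulative_preemption + 1)
        self.num_cumulative_preemption += 1
```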
" + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + if preemption_mode == PreemptionMode.RECOMPUTE: self._preempt_by_recompute(seq_group) elif preemption_mode == PreemptionMode.SWAP: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46fa41030b4a..e258a3f4afd5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -737,6 +737,8 @@ class LLMEngine: num_generation_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) # Request stats # Latency @@ -830,7 +832,6 @@ class LLMEngine: return Stats( now=now, - # System stats # Scheduler State num_running_sys=num_running_sys, @@ -846,6 +847,7 @@ class LLMEngine: time_to_first_tokens_iter=time_to_first_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, # Request stats # Latency diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 3c4aac91549a..ae7ae144bc04 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -61,6 +61,10 @@ class Metrics: labelnames=labelnames) # Iteration stats + self.counter_num_preemption = Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames) self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -181,6 +185,7 @@ class Stats: num_generation_tokens_iter: int time_to_first_tokens_iter: List[float] time_per_output_tokens_iter: List[float] + num_preemption_iter: int # Request stats (should have _requests suffix) # Latency @@ -244,6 +249,8 @@ class StatLogger: stats.cpu_cache_usage_sys) # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) self._log_counter(self.metrics.counter_generation_tokens, @@ -336,7 +343,7 @@ class StatLogger: "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%", + "CPU KV cache usage: %.1f%%.", prompt_throughput, generation_throughput, stats.num_running_sys,