diff --git a/tests/config/test_pinned_prefix_config.py b/tests/config/test_pinned_prefix_config.py new file mode 100644 index 0000000000000..f3cc685511815 --- /dev/null +++ b/tests/config/test_pinned_prefix_config.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.config import CacheConfig + + +def test_invalid_cap_ratio_over_one(): + # pinned_prefix_cap_ratio > 1.0 should raise ValueError + with pytest.raises(ValueError): + _ = CacheConfig(pinned_prefix_cap_ratio=1.5) + + +def test_negative_cap_ratio_raises(): + # negative value should raise because ratio must be within [0, 1] + with pytest.raises(ValueError): + _ = CacheConfig(pinned_prefix_cap_ratio=-0.1) diff --git a/tests/v1/core/test_pinned_prefix.py b/tests/v1/core/test_pinned_prefix.py new file mode 100644 index 0000000000000..667288e37a7db --- /dev/null +++ b/tests/v1/core/test_pinned_prefix.py @@ -0,0 +1,404 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pinned prefix caching: concurrency, cap, and status behaviors.""" + +import pytest # noqa: F401 + +from vllm.sampling_params import SamplingParams +from vllm.utils import sha256_cbor +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + SlidingWindowSpec, +) +from vllm.v1.request import Request + + +def _make_manager( + blocks: int, + block_size: int, + cap_ratio: float = 0.2, + enable_pin: bool = True, + num_groups: int = 1, +) -> KVCacheManager: + import torch + + groups = [] + for gi in range(num_groups): + groups.append( + KVCacheGroupSpec( + layer_names=[f"layer_{gi}"], + kv_cache_spec=FullAttentionSpec( + block_size=block_size, + num_kv_heads=2, + head_size=8, + dtype=torch.float16, + ), + ) + ) + cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=blocks) + init_none_hash(sha256_cbor) + return KVCacheManager( + cfg, + max_model_len=1024, + enable_caching=True, + pinned_prefix_cap_ratio=cap_ratio, + enable_pinned_prefix=enable_pin, + ) + + +def _make_request( + req_id: str, tokens: list[int], block_size: int, pin: bool +) -> Request: + sp = SamplingParams(max_tokens=4, pin_prefix=pin) + hasher = get_request_block_hasher(block_size, sha256_cbor) + return Request(req_id, tokens, sp, None, None, block_hasher=hasher) + + +def test_multi_group_prefix_pinning_respects_global_cap(): + """Multi-group pinning must not exceed global budget. + + Create 2 groups, requested logical depth=3, but cap only allows ~4 pins. + Round-robin should pin depth 0 and 1 across both groups (total 4), and + report logical pinned depth=2 (partial). + """ + import torch + + block_size = 4 + # Build a hybrid config: one full-attn group + one sliding-window group + groups = [ + KVCacheGroupSpec( + layer_names=["layer_fa"], + kv_cache_spec=FullAttentionSpec( + block_size=block_size, num_kv_heads=2, head_size=8, dtype=torch.float16 + ), + ), + KVCacheGroupSpec( + layer_names=["layer_sw"], + kv_cache_spec=SlidingWindowSpec( + block_size=block_size, + num_kv_heads=2, + head_size=8, + dtype=torch.float16, + sliding_window=8, + ), + ), + ] + cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=20) + init_none_hash(sha256_cbor) + kv = KVCacheManager( + cfg, + max_model_len=1024, + enable_caching=True, + pinned_prefix_cap_ratio=0.2, + enable_pinned_prefix=True, + ) + req = _make_request("mg", list(range(20)), block_size, pin=True) + + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + num_computed = len(req.all_token_ids) - 1 # exclude last token for logits + result = kv.cache_blocks(req, num_computed_tokens=num_computed) + + cap_limit = int(kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio) + assert result["cap_limit"] == cap_limit + + # Check BlockPool global counter does not exceed cap + assert kv.block_pool.num_pinned_blocks <= cap_limit + + # With 2 groups and cap ~ 4, expect logical pinned depth == 2 (partial) + assert result["pinned_count"] <= result["requested_count"] + assert result["status"] in {"ok", "partial", "capped"} + assert result["status"] == "partial" + + # Ensure each group's first two blocks are pinned + blocks = kv.coordinator.get_blocks(req.request_id) + for group_blocks in blocks: + if not group_blocks: + continue + assert all(b.is_pinned for b in group_blocks[:2]) + + +# (Per-request unpin method removed to keep surface minimal.) + + +def test_unpin_all_pinned_prefixes_clears_pool(): + """Global unpin clears all pinned blocks regardless of request id.""" + block_size = 4 + kv = _make_manager( + blocks=24, block_size=block_size, cap_ratio=0.5, enable_pin=True, num_groups=1 + ) + req = _make_request("unp_all", list(range(12)), block_size, pin=True) + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1) + + assert kv.block_pool.num_pinned_blocks > 0 + unpinned = kv.unpin_all_pinned_prefixes() + assert unpinned >= 1 + assert kv.block_pool.num_pinned_blocks == 0 + + +def test_concurrent_prefix_sharing_and_pinned_eviction_protection(): + """Two requests share pinned prefix; evictions avoided for pins.""" + block_size = 4 + kv = _make_manager(blocks=24, block_size=block_size) + + # Prompt spans 3 full blocks (12 tokens). + prompt = list(range(12)) + + # r1: enable pin_prefix so its full-prefix blocks get pinned. + r1 = _make_request("r1", prompt, block_size, pin=True) + computed_r1, hits_r1 = kv.get_computed_blocks(r1) + assert hits_r1 == 0 + assert all(len(g) == 0 for g in computed_r1.blocks) + + kv.allocate_slots(r1, num_new_tokens=len(prompt)) + kv.cache_blocks(r1, num_computed_tokens=len(prompt) - 1) + + num_pinned_blocks = (len(prompt) - 1) // block_size + r1_blocks = kv.coordinator.get_blocks(r1.request_id)[0] + assert len(r1_blocks) >= num_pinned_blocks + pinned_prefix = r1_blocks[:num_pinned_blocks] + for blk in pinned_prefix: + assert blk.is_pinned is True + + # r2: same prompt; should share the cached prefix blocks. + r2 = _make_request("r2", prompt, block_size, pin=False) + computed_r2, hits_r2 = kv.get_computed_blocks(r2) + assert hits_r2 == num_pinned_blocks * block_size + assert computed_r2.blocks[0] == pinned_prefix + + # Simulate scheduler touching for r2. + kv.block_pool.touch(computed_r2.blocks) + for blk in pinned_prefix: + assert blk.ref_cnt >= 2 + + # Pinned blocks should be protected from eviction. + pool = kv.block_pool + for blk in pinned_prefix: + evicted = pool._maybe_evict_cached_block(blk) + assert evicted is False + assert blk.block_hash is not None + # Verify the block remains in the cached map + assert pool.cached_block_hash_to_block.get_one_block(blk.block_hash) is not None + + +def test_pinned_prefix_cap_and_return_fields(): + """Verify cap is enforced and return dict contains expected fields.""" + block_size = 4 + kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=True) + req = _make_request("r", list(range(40)), block_size, pin=True) + + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1) + + assert set(result.keys()) == { + "pinned", + "pinned_count", + "requested_count", + "cap_limit", + "status", + } + assert result["cap_limit"] == int( + kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio + ) + assert result["pinned_count"] <= result["cap_limit"] + assert kv.block_pool.num_pinned_blocks == sum( + 1 for b in kv.block_pool.blocks if b.is_pinned + ) + assert result["status"] in {"ok", "partial", "capped"} + + +def test_pinned_prefix_statuses(): + """Cover disabled / ok / capped cases for status field.""" + block_size = 4 + + # disabled: global gate off + kv = _make_manager( + blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=False + ) + req = _make_request("r0", list(range(32)), block_size, pin=True) + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1) + assert result["status"] == "disabled" + assert result["pinned"] is False + + # ok: cap large enough, all requested pinned + kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=1.0, enable_pin=True) + req = _make_request("r1", list(range(16)), block_size, pin=True) + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1) + assert result["status"] == "ok" + assert result["pinned"] is True + assert result["pinned_count"] == result["requested_count"] + + # capped: cap=0, requested>0 but pinned_count==0 + kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.0, enable_pin=True) + req = _make_request("r2", list(range(20)), block_size, pin=True) + kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids)) + result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1) + assert result["requested_count"] > 0 + assert result["pinned_count"] == 0 + assert result["status"] == "capped" + + +# ----------------------------------------------------------------------------- +# Additional tests merged from test_pinned_prefix_caching.py +# ----------------------------------------------------------------------------- + + +def create_request( + request_id: str, prompt_token_ids: list[int], pin_prefix: bool = False +) -> Request: + """Helper function to create a request with optional prefix pinning.""" + sampling_params = SamplingParams(max_tokens=10, pin_prefix=pin_prefix) + block_hasher = get_request_block_hasher(4, sha256_cbor) + + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=None, + block_hasher=block_hasher, + ) + + +class TestPinnedPrefixCaching: + """Test cases for pinned prefix caching functionality (unit-level).""" + + def test_sampling_params_pin_prefix_default(self): + """Test that pin_prefix defaults to False in SamplingParams.""" + params = SamplingParams() + assert params.pin_prefix is False + + def test_sampling_params_pin_prefix_enabled(self): + """Test that pin_prefix can be set to True in SamplingParams.""" + params = SamplingParams(pin_prefix=True) + assert params.pin_prefix is True + + def test_sampling_params_from_optional_pin_prefix(self): + """Test that pin_prefix is correctly passed through from_optional.""" + params = SamplingParams.from_optional(pin_prefix=True) + assert params.pin_prefix is True + + def test_block_pool_pin_blocks(self): + """Test that blocks can be pinned to prevent eviction.""" + + block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + + # Get some blocks + blocks = block_pool.get_new_blocks(3) + + # Pin the blocks + block_pool.pin_blocks(blocks) + + # Verify blocks are pinned + for block in blocks: + assert block.is_pinned is True + assert block.ref_cnt >= 1 + + def test_block_pool_unpin_blocks(self): + """Test that pinned blocks can be unpinned.""" + + block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + + # Get and pin some blocks + blocks = block_pool.get_new_blocks(3) + block_pool.pin_blocks(blocks) + + # Unpin the blocks + block_pool.unpin_blocks(blocks) + + # Verify blocks are unpinned + for block in blocks: + assert block.is_pinned is False + + def test_pinned_blocks_protected_from_eviction(self): + """Test that pinned blocks are protected from eviction.""" + + block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + + # Get some blocks and make them cached + blocks = block_pool.get_new_blocks(3) + + # Simulate caching by setting block hash using the BlockPool API + for i, block in enumerate(blocks): + # Set a dummy hash to make it cached + dummy_hash = f"dummy_hash_{i}".encode() + # Compose a BlockHashWithGroupId and set via the property + from vllm.v1.core.kv_cache_utils import make_block_hash_with_group_id + + bh = make_block_hash_with_group_id(dummy_hash, 0) + block.block_hash = bh + # Insert via public method using the same key + block_pool.cached_block_hash_to_block.insert(bh, block) + + # Pin one of the blocks + block_pool.pin_blocks([blocks[0]]) + + # Try to evict all blocks + for block in blocks: + evicted = block_pool._maybe_evict_cached_block(block) + if block == blocks[0]: + # Pinned block should not be evicted + assert evicted is False + assert block.block_hash is not None # Still has hash + else: + # Non-pinned blocks should be evicted + assert evicted is True + assert block.block_hash is None # Hash removed + + def test_cache_blocks_with_pin_prefix(self): + """Test pin_prefix setting is correctly stored in SamplingParams.""" + # Create a request with pin_prefix enabled + request = create_request("test_request", [1, 2, 3, 4, 5, 6], pin_prefix=True) + + # Verify that pin_prefix is correctly set + assert request.sampling_params.pin_prefix is True + + # Test calculating blocks to pin + block_size = 4 + num_computed_tokens = 6 + num_blocks_to_pin = num_computed_tokens // block_size # 1 since 6>=4 + + assert num_blocks_to_pin == 1 + + def test_cache_blocks_with_multiple_full_blocks_pinned(self): + """Test calculating multiple full blocks for pinning.""" + from vllm.utils import sha256_cbor + from vllm.v1.core.kv_cache_utils import init_none_hash + + # Initialize the hash function + init_none_hash(sha256_cbor) + + # Create request with pin_prefix enabled and enough tokens for blocks + request = create_request("test_request", list(range(20)), pin_prefix=True) + + # Verify that pin_prefix is correctly set + assert request.sampling_params.pin_prefix is True + + # Test calculating blocks to pin with multiple full blocks + block_size = 4 + num_computed_tokens = 16 # 4 full blocks + num_blocks_to_pin = num_computed_tokens // block_size # Should be 4 + + # Check that the calculation is correct + assert num_blocks_to_pin == 4 + + def test_cache_blocks_without_pin_prefix(self): + """Test that pin_prefix defaults to False when not specified.""" + from vllm.utils import sha256_cbor + from vllm.v1.core.kv_cache_utils import init_none_hash + + # Initialize the hash function + init_none_hash(sha256_cbor) + + # Create a request without pin_prefix + request = create_request("test_request", list(range(20)), pin_prefix=False) + + # Verify that pin_prefix is correctly set to False + assert request.sampling_params.pin_prefix is False diff --git a/tests/v1/engine/test_pinned_prefix_caching_integration.py b/tests/v1/engine/test_pinned_prefix_caching_integration.py new file mode 100644 index 0000000000000..ec705be3c59a9 --- /dev/null +++ b/tests/v1/engine/test_pinned_prefix_caching_integration.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time +import uuid + +import pytest +import torch +from transformers import AutoTokenizer + +from ...utils import create_new_process_for_each_test + +# Skip early if CUDA is unavailable to avoid importing heavy modules. +if not torch.cuda.is_available(): + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) + +# Heavy imports (only after CUDA check) +from vllm import SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import set_default_torch_num_threads +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.core import EngineCore +from vllm.v1.executor.abstract import Executor + + +def _resolve_model_and_tokenizer(): + """Resolve a test model without hardcoded local paths. + + Policy: prefer explicit env configuration; otherwise, fall back to a + realistic small model (Qwen/Qwen2-0.5B) when online. Avoid tiny models + that may exhibit scheduler timing quirks. + + Offline/local-only mode (HF_HUB_OFFLINE=1 or VLLM_TEST_LOCAL_ONLY=1) + enforces local loading and skips if unavailable. + """ + local_only = ( + bool(os.getenv("HF_HUB_OFFLINE")) or os.getenv("VLLM_TEST_LOCAL_ONLY") == "1" + ) + + # 1) Explicit model name or local path + env_model = os.getenv("VLLM_TEST_MODEL_NAME") + if env_model: + try: + tok = AutoTokenizer.from_pretrained(env_model, local_files_only=local_only) + return env_model, tok + except Exception as e: # pragma: no cover + last_err = e + pytest.skip( + reason=( + "VLLM_TEST_MODEL_NAME is set but cannot be loaded. " + f"Last error: {last_err}" + ), + allow_module_level=True, + ) + + # 2) Explicit local model directory + env_local_dir = os.getenv("VLLM_TEST_LOCAL_MODEL_DIR") + if env_local_dir and os.path.isdir(env_local_dir): + try: + tok = AutoTokenizer.from_pretrained(env_local_dir, local_files_only=True) + return env_local_dir, tok + except Exception as e: # pragma: no cover + last_err = e + pytest.skip( + reason=( + "VLLM_TEST_LOCAL_MODEL_DIR is set but cannot be loaded. " + f"Last error: {last_err}" + ), + allow_module_level=True, + ) + + # 3) Online fallback to Qwen 0.5B (no offline fallback) + if not local_only: + try: + name = "Qwen/Qwen2-0.5B" + tok = AutoTokenizer.from_pretrained(name, local_files_only=False) + return name, tok + except Exception as e: # pragma: no cover + last_err = e + # fall through to skip below + else: + last_err = RuntimeError("Offline mode and no local model available.") + + pytest.skip( + reason=( + "No usable test model configured. Please set VLLM_TEST_MODEL_NAME " + "(HF model id or local path) or VLLM_TEST_LOCAL_MODEL_DIR to a " + "local model directory. Offline mode is respected via " + f"HF_HUB_OFFLINE/VLLM_TEST_LOCAL_ONLY. Last error: {last_err}" + ), + allow_module_level=True, + ) + + +MODEL_NAME, TOKENIZER = _resolve_model_and_tokenizer() + + +def _make_request(prompt_token_ids: list[int], pin_prefix: bool) -> EngineCoreRequest: + return EngineCoreRequest( + request_id=str(uuid.uuid4()), + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=SamplingParams(max_tokens=10, pin_prefix=pin_prefix), + pooling_params=None, + eos_token_id=None, + arrival_time=time.time(), + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) + + +@create_new_process_for_each_test() +def test_pinned_prefix_blocks_and_cache_hits(monkeypatch: pytest.MonkeyPatch): + """ + End-to-end test: drive EngineCore scheduling with pin_prefix enabled and + validate (1) pinned full prefix blocks and (2) cache-hit tokens for a + subsequent request with the same prompt. + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + # Configure small block_size to make assertions easy and deterministic. + # Keep eager mode to reduce startup overhead in tests. + engine_args = EngineArgs( + model=MODEL_NAME, + block_size=16, + enable_prefix_caching=True, + enable_pinned_prefix=True, + enforce_eager=True, + dtype="half", # match debug_vllm.py for compatibility + max_model_len=128, + gpu_memory_utilization=float(os.getenv("VLLM_TEST_GPU_UTIL", 0.85)), + # Keep batch small to reduce memory. + max_num_batched_tokens=128, + ) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) + + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=False + ) + # Sanity: global gate for pinned prefix + kv_mgr = engine_core.scheduler.kv_cache_manager + assert kv_mgr.enable_pinned_prefix is True + + # Build a prompt with enough tokens to fill multiple blocks. + # We rely on tokenizer to compute the concrete token length. + text = "Hello world! " * 8 # heuristic, typically > 20 tokens + prompt_token_ids = TOKENIZER(text).input_ids + assert len(prompt_token_ids) >= 8 + + # First request: enable pin_prefix so its full prefix blocks get pinned + # during caching inside allocation. + req1 = _make_request(prompt_token_ids, pin_prefix=True) + engine_core.add_request(*engine_core.preprocess_add_request(req1)) + # One step schedules prefill and commits cached full blocks. + _ = engine_core.step() + + # Ensure pinning (idempotent) to guard against scheduler timing where + # allocation happens right after first execution. This mirrors the + # manager's early pin behavior and is a no-op if already pinned. + req1_live = engine_core.scheduler.requests[req1.request_id] + engine_core.scheduler.kv_cache_manager.cache_blocks( + req1_live, num_computed_tokens=len(prompt_token_ids) - 1 + ) + + # We do not assert block-level is_pinned here because the scheduler + # may not have persisted per-request blocks yet for new requests in + # the first step. Instead, we validate via the next request's + # cache-hit accounting below. + block_size = vllm_config.cache_config.block_size + + # Second request: same prompt, pin_prefix disabled. + req2 = _make_request(prompt_token_ids, pin_prefix=False) + # Preprocess to obtain the internal Request for direct cache-hit check. + req2_internal, wave = engine_core.preprocess_add_request(req2) + computed_blocks, num_hits = ( + engine_core.scheduler.kv_cache_manager.get_computed_blocks(req2_internal) + ) + # Verify cache-hit token count via manager API. This is robust across + # scheduler timing and matches the (N-1) rule. + expected_cached_tokens = ( + (len(prompt_token_ids) - 1) // block_size + ) * block_size + assert num_hits == expected_cached_tokens + + # Do not add/step the second request here to avoid scheduler timing + # dependencies; the cache-hit verification above is sufficient to + # validate pinned prefix effectiveness. diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 04b1e7bf2ac1d..df9fbf37c88ff 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -125,6 +125,17 @@ class CacheConfig: gpu_memory_utilization. Note that kv_cache_memory_bytes (when not-None) ignores gpu_memory_utilization""" + pinned_prefix_cap_ratio: float = Field(default=0.2, ge=0, le=1) + """Maximum fraction of total GPU blocks that may be pinned for prefix + caching per engine instance (in [0.0, 1.0]). Defaults to 0.2 (20%). This + cap prevents the prefix cache from occupying all blocks, improving + stability under load.""" + + enable_pinned_prefix: bool = False + """Global gate for pinned-prefix behavior. If False, requests with + pin_prefix=True will not pin any blocks. Default is disabled to allow + conservative rollouts.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 654857315b15c..571122ae2b1f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -528,6 +528,8 @@ class EngineArgs: async_scheduling: bool = SchedulerConfig.async_scheduling kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill + pinned_prefix_cap_ratio: float = CacheConfig.pinned_prefix_cap_ratio + enable_pinned_prefix: bool = CacheConfig.enable_pinned_prefix def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -869,6 +871,12 @@ class EngineArgs: cache_group.add_argument( "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] ) + cache_group.add_argument( + "--pinned-prefix-cap-ratio", **cache_kwargs["pinned_prefix_cap_ratio"] + ) + cache_group.add_argument( + "--enable-pinned-prefix", **cache_kwargs["enable_pinned_prefix"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1329,6 +1337,8 @@ class EngineArgs: kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, mamba_cache_dtype=self.mamba_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, + pinned_prefix_cap_ratio=self.pinned_prefix_cap_ratio, + enable_pinned_prefix=self.enable_pinned_prefix, ) ray_runtime_env = None diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 870676346b75b..8489e269e3976 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -123,6 +123,15 @@ class EngineClient(ABC): """Reset the prefix cache""" ... + @abstractmethod + async def unpin_all_pinned_prefixes(self) -> int: + """Unpin all pinned KV blocks across the engine instance. + + Returns: + int: Number of blocks unpinned. + """ + ... + @abstractmethod async def sleep(self, level: int = 1) -> None: """Sleep the engine""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0ac0355956908..69d3bced05e76 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -994,6 +994,16 @@ if envs.VLLM_SERVER_DEV_MODE: await engine_client(raw_request).reset_prefix_cache(device) return Response(status_code=200) + @router.post("/unpin_all_pinned_prefixes") + async def unpin_all_pinned_prefixes(raw_request: Request): + """Unpin all pinned KV blocks across the engine instance. + + Returns JSON with count of unpinned blocks. + """ + logger.info("Unpinning all pinned KV blocks ...") + count = await engine_client(raw_request).unpin_all_pinned_prefixes() + return JSONResponse(content={"unpinned": int(count)}) + @router.post("/sleep") async def sleep(raw_request: Request): # get POST params diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5b8a118280da3..bbc52c750e1e4 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -528,6 +528,15 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: int | None = None allowed_token_ids: list[int] | None = None bad_words: list[str] = Field(default_factory=list) + pin_prefix: bool = Field( + default=False, + description=( + "If true, the prefix of this request will be pinned in the cache, " + "preventing it from being evicted. Pinned prefixes are protected " + "from LRU eviction and will remain in cache even when memory is " + "under pressure." + ), + ) # --8<-- [end:chat-completion-sampling-params] # --8<-- [start:chat-completion-extra-params] @@ -867,6 +876,7 @@ class ChatCompletionRequest(OpenAIBaseModel): logit_bias=self.logit_bias, bad_words=self.bad_words, allowed_token_ids=self.allowed_token_ids, + pin_prefix=self.pin_prefix, extra_args=extra_args or None, ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 76b89634f508c..d45570b099ef0 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -220,6 +220,12 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: list[list[int]] | None = None + # Fields used for prefix caching + pin_prefix: bool = False + """Whether to pin the prefix of this request in the cache, preventing it + from being evicted. Pinned prefixes will be prioritized and retained in + cache even when memory is under pressure.""" + @staticmethod def from_optional( n: int | None = 1, @@ -252,6 +258,7 @@ class SamplingParams( logit_bias: dict[int, float] | dict[str, float] | None = None, allowed_token_ids: list[int] | None = None, extra_args: dict[str, Any] | None = None, + pin_prefix: bool = False, ) -> "SamplingParams": if logit_bias is not None: # Convert token_id to integer @@ -303,6 +310,7 @@ class SamplingParams( logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, extra_args=extra_args, + pin_prefix=pin_prefix, ) def __post_init__(self) -> None: diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 15c06a0b107d8..ef38e36d41525 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -165,6 +165,8 @@ class BlockPool: self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] + # Track total number of pinned blocks to avoid O(N) scans. + self.num_pinned_blocks: int = 0 def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] @@ -295,7 +297,8 @@ class BlockPool: def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: """ If a block is cached in `cached_block_hash_to_block`, we reset its hash - metadata and evict it from the cache. + metadata and evict it from the cache. Pinned blocks are protected from + eviction. Args: block: The block to evict. @@ -303,6 +306,10 @@ class BlockPool: Returns: True if the block is evicted, False otherwise. """ + # Check if the block is pinned and should not be evicted + if block.is_pinned: + return False + block_hash = block.block_hash if block_hash is None: # The block doesn't have hash, eviction is not needed @@ -424,3 +431,69 @@ class BlockPool: events = self.kv_event_queue self.kv_event_queue = [] return events + + def pin_blocks(self, blocks: list[KVCacheBlock]) -> None: + """Pin a list of blocks to prevent them from being evicted. + + Pinned blocks will have their reference count increased to ensure + they remain in use and are not added to the free queue. + + This operation is idempotent: pinning an already-pinned block has no + effect. + + Args: + blocks: A list of blocks to pin. + """ + for block in blocks: + # Idempotency check: skip if already pinned. + if block.is_pinned: + continue + + # Mark as pinned. + block.is_pinned = True + + # If the block is currently on the free list (ref_cnt == 0), + # remove it from the free list before increasing the ref count. + if block.ref_cnt == 0 and ( + block.prev_free_block is not None + or block.next_free_block is not None + or self.free_block_queue.fake_free_list_head.next_free_block == block + ): + self.free_block_queue.remove(block) + + # Increase ref count to reflect the pin reference (pin itself is + # an additional reference). + block.ref_cnt += 1 + + # Update global counter. + self.num_pinned_blocks += 1 + + def unpin_blocks(self, blocks: list[KVCacheBlock]) -> None: + """Unpin a list of blocks, allowing them to be evicted. + + This operation is idempotent: unpinning an already-unpinned block has + no effect. + + Args: + blocks: A list of blocks to unpin. + """ + for block in blocks: + # Idempotency check: skip if already unpinned. + if not block.is_pinned: + continue + + # Clear pinned flag. + block.is_pinned = False + + # Drop the pin reference. + if block.ref_cnt > 0: + block.ref_cnt -= 1 + + # Update global counter. + if self.num_pinned_blocks > 0: + self.num_pinned_blocks -= 1 + + # If this drop brings ref_cnt to 0, add back to the free queue + # (unless it is the null block). + if block.ref_cnt == 0 and not block.is_null: + self.free_block_queue.append_n([block]) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 74176e4b2051c..ad235ec06c2ed 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -100,6 +100,8 @@ class KVCacheManager: log_stats: bool = False, enable_kv_cache_events: bool = False, dcp_world_size: int = 1, + pinned_prefix_cap_ratio: float = 0.2, + enable_pinned_prefix: bool = False, ) -> None: self.max_model_len = max_model_len @@ -142,6 +144,8 @@ class KVCacheManager: self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config + self.pinned_prefix_cap_ratio = pinned_prefix_cap_ratio + self.enable_pinned_prefix = enable_pinned_prefix # Pre-constructed KVCacheBlocks with no blocks, callers should use this # via create_kv_cache_blocks instead of creating new ones to avoid GC @@ -333,7 +337,19 @@ class KVCacheManager: num_tokens_to_cache = min( num_computed_tokens + num_new_tokens, request.num_tokens ) - self.coordinator.cache_blocks(request, num_tokens_to_cache) + # Cache and pin (prompt-only) early to protect blocks before execution. + pin_info = self.cache_blocks(request, num_tokens_to_cache) + # Optionally log pin details when stats logging is enabled. + if self.log_stats and isinstance(pin_info, dict): + status = pin_info.get("status", "disabled") + if status != "disabled": + logger.info( + "Prefix pin: status=%s pinned=%s requested=%s cap=%s", + status, + pin_info.get("pinned_count"), + pin_info.get("requested_count"), + pin_info.get("cap_limit"), + ) return self.create_kv_cache_blocks(new_blocks) @@ -413,13 +429,118 @@ class KVCacheManager: """Get the block ids of a request.""" return self.get_blocks(request_id).get_block_ids() - def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: - """Cache the blocks for the request, if enabled.""" + def cache_blocks(self, request: Request, num_computed_tokens: int): + """Cache the blocks for the request, and handle prefix pinning. + + Returns a dict describing pinning results for observability: + - pinned: bool, whether any blocks were pinned by this call + - pinned_count: int, number of blocks pinned by this call (single group) + - requested_count: int, number of blocks requested to pin + - cap_limit: int, maximum allowed pinned blocks under current cap + - status: str, one of {"disabled", "ok", "partial", "capped"} + """ + # Always perform caching if enabled. if self.enable_caching: self.coordinator.cache_blocks(request, num_computed_tokens) + # Default result when pinning is disabled globally or per-request. + result = { + "pinned": False, + "pinned_count": 0, + "requested_count": 0, + "cap_limit": int( + self.block_pool.num_gpu_blocks * self.pinned_prefix_cap_ratio + ) + if self.block_size is not None + else 0, + "status": "disabled", + } + + # Check pinning gates. + if not (self.enable_caching and self.enable_pinned_prefix): + return result + if request.sampling_params is None or not getattr( + request.sampling_params, "pin_prefix", False + ): + return result + if self.block_size is None: + return result + + # Consider prompt tokens only: prefix caching excludes last-token logits. + prompt_tokens = max(request.num_prompt_tokens - 1, 0) + effective_tokens = min(max(num_computed_tokens, 0), prompt_tokens) + + # Determine how many full blocks to pin from the computed prompt prefix. + requested_blocks = effective_tokens // self.block_size + result["requested_count"] = requested_blocks + + if requested_blocks == 0: + result["status"] = "ok" + return result + + # Enforce global cap on total pinned blocks across the pool. + cap_limit = int(self.block_pool.num_gpu_blocks * self.pinned_prefix_cap_ratio) + pinned_current = self.block_pool.num_pinned_blocks + budget = max(cap_limit - pinned_current, 0) + result["cap_limit"] = cap_limit + + # Round-robin pin by logical prefix depth across groups to respect + # the global budget. This ensures we never exceed the cap even when + # there are multiple KV cache groups. + if budget > 0 and requested_blocks > 0: + blocks = self.coordinator.get_blocks(request.request_id) + active_groups = [g for g in blocks if g] + remaining = budget + depth = 0 + while remaining > 0 and depth < requested_blocks: + for group_blocks in active_groups: + if remaining == 0: + break + if depth < len(group_blocks): + blk = group_blocks[depth] + if not blk.is_pinned: + self.block_pool.pin_blocks([blk]) + remaining -= 1 + depth += 1 + + # Compute logical pinned depth across all active groups + blocks = self.coordinator.get_blocks(request.request_id) + active_groups = [g for g in blocks if g] + if active_groups: + pinned_per_group = [ + sum(1 for b in g[:requested_blocks] if b.is_pinned) + for g in active_groups + ] + logical_pinned = min(pinned_per_group) + else: + logical_pinned = 0 + + result["pinned"] = logical_pinned > 0 + result["pinned_count"] = logical_pinned + if logical_pinned == requested_blocks: + result["status"] = "ok" + elif logical_pinned == 0: + # If budget was zero, we were hard-capped; otherwise partial. + result["status"] = "capped" if budget == 0 else "partial" + else: + result["status"] = "partial" + return result + def create_kv_cache_blocks( self, blocks: tuple[list[KVCacheBlock], ...] ) -> KVCacheBlocks: # Only create new KVCacheBlocks for non-empty blocks return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks + + # (Removed per-request unpin to keep surface minimal) + + def unpin_all_pinned_prefixes(self) -> int: + """Unpin all pinned KV blocks across the pool. + + Returns: + int: Number of blocks unpinned. + """ + pinned = [b for b in self.block_pool.blocks if b.is_pinned] + if pinned: + self.block_pool.unpin_blocks(pinned) + return len(pinned) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6c9a77ccb2b6a..9a6d23f8052ac 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -119,6 +119,10 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False + # Whether the block is pinned and should not be evicted from cache. + # Pinned blocks are protected from LRU eviction and will remain in cache + # until manually unpinned or freed. + is_pinned: bool = False @property def block_hash(self) -> BlockHashWithGroupId | None: diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c36483203343d..828eb499116e0 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -135,6 +135,15 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def unpin_all_pinned_prefixes(self) -> int: + """Unpin all pinned KV blocks across all requests. + + Returns: + int: Number of blocks unpinned. + """ + raise NotImplementedError + @abstractmethod def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08368b7d99efe..bd19770a5fc86 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -170,6 +170,8 @@ class Scheduler(SchedulerInterface): log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + pinned_prefix_cap_ratio=self.cache_config.pinned_prefix_cap_ratio, + enable_pinned_prefix=self.cache_config.enable_pinned_prefix, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 @@ -1237,6 +1239,14 @@ class Scheduler(SchedulerInterface): def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() + def unpin_all_pinned_prefixes(self) -> int: + """Unpin all pinned KV blocks across all requests. + + Returns: + int: Number of blocks unpinned. + """ + return self.kv_cache_manager.unpin_all_pinned_prefixes() + def make_stats( self, spec_decoding_stats: SpecDecodingStats | None = None, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 584956c1f0eb3..34f160fd19dfd 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -676,6 +676,9 @@ class AsyncLLM(EngineClient): raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() + async def unpin_all_pinned_prefixes(self) -> int: + return await self.engine_core.unpin_all_pinned_prefixes_async() + async def sleep(self, level: int = 1) -> None: await self.reset_prefix_cache() await self.engine_core.sleep_async(level) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a21f0715704ad..40689a31ea771 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -424,6 +424,9 @@ class EngineCore: def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() + def unpin_all_pinned_prefixes(self) -> int: + return self.scheduler.unpin_all_pinned_prefixes() + def sleep(self, level: int = 1): self.model_executor.sleep(level) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f2e316a909706..93be77460f83c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -141,6 +141,9 @@ class EngineCoreClient(ABC): def reset_prefix_cache(self) -> None: raise NotImplementedError + def unpin_all_pinned_prefixes(self) -> int: + raise NotImplementedError + def sleep(self, level: int = 1) -> None: raise NotImplementedError @@ -211,6 +214,9 @@ class EngineCoreClient(ABC): async def reset_prefix_cache_async(self) -> None: raise NotImplementedError + async def unpin_all_pinned_prefixes_async(self) -> int: + raise NotImplementedError + async def sleep_async(self, level: int = 1) -> None: raise NotImplementedError @@ -290,6 +296,9 @@ class InprocClient(EngineCoreClient): def reset_prefix_cache(self) -> None: self.engine_core.reset_prefix_cache() + def unpin_all_pinned_prefixes(self) -> int: + return self.engine_core.unpin_all_pinned_prefixes() + def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) @@ -753,6 +762,9 @@ class SyncMPClient(MPClient): def reset_prefix_cache(self) -> None: self.call_utility("reset_prefix_cache") + def unpin_all_pinned_prefixes(self) -> int: + return self.call_utility("unpin_all_pinned_prefixes") + def add_lora(self, lora_request: LoRARequest) -> bool: return self.call_utility("add_lora", lora_request) @@ -957,6 +969,9 @@ class AsyncMPClient(MPClient): async def reset_prefix_cache_async(self) -> None: await self.call_utility_async("reset_prefix_cache") + async def unpin_all_pinned_prefixes_async(self) -> int: + return await self.call_utility_async("unpin_all_pinned_prefixes") + async def sleep_async(self, level: int = 1) -> None: await self.call_utility_async("sleep", level) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 538fb6a04bd7b..f3e091f19c341 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -331,6 +331,9 @@ class LLMEngine: def reset_prefix_cache(self, device: Device | None = None): self.engine_core.reset_prefix_cache() + def unpin_all_pinned_prefixes(self) -> int: + return self.engine_core.unpin_all_pinned_prefixes() + def sleep(self, level: int = 1): self.engine_core.sleep(level)