Compare commits


2 Commits

Author SHA1 Message Date
dongbo910220
1efe104097
Merge 3e65806b30faa70248d7afbe8398468f03d6bfd3 into 6c9fdbf7258146a9e335c50aab12969cd95e9227 2025-10-17 11:08:34 +00:00
dongbo910220
3e65806b30 feat(v1): Implement pinned prefix caching with global unpin API
Core Features:
- Add pin_prefix parameter to SamplingParams for per-request prefix pinning
- Implement pinned prefix caching in V1 engine KVCacheManager
- Add pinned_prefix_cap_ratio (default 0.2) to control memory usage
- Add enable_pinned_prefix global gate for conservative rollouts
- Protect pinned blocks from LRU eviction in BlockPool

Bug Fixes:
- Fix multi-group budget bug with round-robin pinning strategy
- Ensure global cap is never exceeded even with multiple KV cache groups
- Use logical pinned depth (min across groups) for accurate reporting

Management APIs:
- Add HTTP endpoint POST /unpin_all_pinned_prefixes for memory reclamation
- Implement complete call chain: API -> AsyncLLM -> EngineCore -> Scheduler -> KVCacheManager
- Remove per-request unpin to keep API surface minimal

Code Quality:
- Replace manual @field_validator with Field(ge=0, le=1) for cleaner validation
- Add comprehensive test coverage (unit + integration + E2E)
- Add test_multi_group_prefix_pinning_respects_global_cap() for multi-group validation
- Add test_unpin_all_pinned_prefixes_clears_pool() for unpin API validation

Resolves: #23083
Signed-off-by: dongbo910220 <1275604947@qq.com>
2025-10-17 19:02:56 +08:00
19 changed files with 917 additions and 5 deletions
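Before the per-file diffs, here is a minimal usage sketch of the feature the commit message describes. It assumes the pin_prefix field on SamplingParams and the new engine flags land exactly as shown in the hunks below; the model name is a placeholder.

# Minimal usage sketch (assumes this change is merged; names follow the diff below).
from vllm import LLM, SamplingParams

# The engine must enable prefix caching and the global pin gate; the cap ratio
# bounds how much GPU block memory pinned prefixes may occupy.
llm = LLM(
    model="Qwen/Qwen2-0.5B",          # placeholder model
    enable_prefix_caching=True,
    enable_pinned_prefix=True,        # added by this change
    pinned_prefix_cap_ratio=0.2,      # added by this change
)

# Per-request opt-in: ask the engine to pin this request's full prefix blocks.
params = SamplingParams(max_tokens=32, pin_prefix=True)
outputs = llm.generate(["You are a helpful assistant. Summarize: ..."], params)

Pinned memory can later be reclaimed globally through the POST /unpin_all_pinned_prefixes endpoint added to the API server (see that hunk further down).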

View File

@ -1224,7 +1224,6 @@ typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.15.0
# via
# aiosignal
# albumentations
# alembic
# chz

View File

@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import CacheConfig
def test_invalid_cap_ratio_over_one():
# pinned_prefix_cap_ratio > 1.0 should raise ValueError
with pytest.raises(ValueError):
_ = CacheConfig(pinned_prefix_cap_ratio=1.5)
def test_negative_cap_ratio_raises():
# negative value should raise because ratio must be within [0, 1]
with pytest.raises(ValueError):
_ = CacheConfig(pinned_prefix_cap_ratio=-0.1)
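These two checks exercise the declarative bound validation the commit message mentions (Field(ge=0, le=1) instead of a manual @field_validator). A standalone sketch of that pattern follows; ToyCacheConfig is hypothetical and only mirrors the validation style, it is not vLLM code.

# Standalone pydantic sketch of the Field(ge=0, le=1) bound check.
from pydantic import BaseModel, Field, ValidationError

class ToyCacheConfig(BaseModel):
    pinned_prefix_cap_ratio: float = Field(default=0.2, ge=0, le=1)

ToyCacheConfig(pinned_prefix_cap_ratio=0.5)       # accepted
try:
    ToyCacheConfig(pinned_prefix_cap_ratio=1.5)   # rejected: above the le=1 bound
except ValidationError:
    pass  # ValidationError subclasses ValueError, so pytest.raises(ValueError) matches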

View File

@ -0,0 +1,404 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pinned prefix caching: concurrency, cap, and status behaviors."""
import pytest # noqa: F401
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request
def _make_manager(
blocks: int,
block_size: int,
cap_ratio: float = 0.2,
enable_pin: bool = True,
num_groups: int = 1,
) -> KVCacheManager:
import torch
groups = []
for gi in range(num_groups):
groups.append(
KVCacheGroupSpec(
layer_names=[f"layer_{gi}"],
kv_cache_spec=FullAttentionSpec(
block_size=block_size,
num_kv_heads=2,
head_size=8,
dtype=torch.float16,
),
)
)
cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=blocks)
init_none_hash(sha256_cbor)
return KVCacheManager(
cfg,
max_model_len=1024,
enable_caching=True,
pinned_prefix_cap_ratio=cap_ratio,
enable_pinned_prefix=enable_pin,
)
def _make_request(
req_id: str, tokens: list[int], block_size: int, pin: bool
) -> Request:
sp = SamplingParams(max_tokens=4, pin_prefix=pin)
hasher = get_request_block_hasher(block_size, sha256_cbor)
return Request(req_id, tokens, sp, None, None, block_hasher=hasher)
def test_multi_group_prefix_pinning_respects_global_cap():
"""Multi-group pinning must not exceed global budget.
Create 2 groups; the prompt requests a logical pin depth of 4 full blocks, but the cap only allows 4 pinned blocks in total.
Round-robin should pin depth 0 and 1 across both groups (total 4), and
report logical pinned depth=2 (partial).
"""
import torch
block_size = 4
# Build a hybrid config: one full-attn group + one sliding-window group
groups = [
KVCacheGroupSpec(
layer_names=["layer_fa"],
kv_cache_spec=FullAttentionSpec(
block_size=block_size, num_kv_heads=2, head_size=8, dtype=torch.float16
),
),
KVCacheGroupSpec(
layer_names=["layer_sw"],
kv_cache_spec=SlidingWindowSpec(
block_size=block_size,
num_kv_heads=2,
head_size=8,
dtype=torch.float16,
sliding_window=8,
),
),
]
cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=20)
init_none_hash(sha256_cbor)
kv = KVCacheManager(
cfg,
max_model_len=1024,
enable_caching=True,
pinned_prefix_cap_ratio=0.2,
enable_pinned_prefix=True,
)
req = _make_request("mg", list(range(20)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
num_computed = len(req.all_token_ids) - 1 # exclude last token for logits
result = kv.cache_blocks(req, num_computed_tokens=num_computed)
cap_limit = int(kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio)
assert result["cap_limit"] == cap_limit
# Check BlockPool global counter does not exceed cap
assert kv.block_pool.num_pinned_blocks <= cap_limit
# With 2 groups and cap ~ 4, expect logical pinned depth == 2 (partial)
assert result["pinned_count"] <= result["requested_count"]
assert result["status"] in {"ok", "partial", "capped"}
assert result["status"] == "partial"
# Ensure each group's first two blocks are pinned
blocks = kv.coordinator.get_blocks(req.request_id)
for group_blocks in blocks:
if not group_blocks:
continue
assert all(b.is_pinned for b in group_blocks[:2])
# (Per-request unpin method removed to keep surface minimal.)
def test_unpin_all_pinned_prefixes_clears_pool():
"""Global unpin clears all pinned blocks regardless of request id."""
block_size = 4
kv = _make_manager(
blocks=24, block_size=block_size, cap_ratio=0.5, enable_pin=True, num_groups=1
)
req = _make_request("unp_all", list(range(12)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert kv.block_pool.num_pinned_blocks > 0
unpinned = kv.unpin_all_pinned_prefixes()
assert unpinned >= 1
assert kv.block_pool.num_pinned_blocks == 0
def test_concurrent_prefix_sharing_and_pinned_eviction_protection():
"""Two requests share pinned prefix; evictions avoided for pins."""
block_size = 4
kv = _make_manager(blocks=24, block_size=block_size)
# Prompt spans 3 full blocks (12 tokens).
prompt = list(range(12))
# r1: enable pin_prefix so its full-prefix blocks get pinned.
r1 = _make_request("r1", prompt, block_size, pin=True)
computed_r1, hits_r1 = kv.get_computed_blocks(r1)
assert hits_r1 == 0
assert all(len(g) == 0 for g in computed_r1.blocks)
kv.allocate_slots(r1, num_new_tokens=len(prompt))
kv.cache_blocks(r1, num_computed_tokens=len(prompt) - 1)
num_pinned_blocks = (len(prompt) - 1) // block_size
r1_blocks = kv.coordinator.get_blocks(r1.request_id)[0]
assert len(r1_blocks) >= num_pinned_blocks
pinned_prefix = r1_blocks[:num_pinned_blocks]
for blk in pinned_prefix:
assert blk.is_pinned is True
# r2: same prompt; should share the cached prefix blocks.
r2 = _make_request("r2", prompt, block_size, pin=False)
computed_r2, hits_r2 = kv.get_computed_blocks(r2)
assert hits_r2 == num_pinned_blocks * block_size
assert computed_r2.blocks[0] == pinned_prefix
# Simulate scheduler touching for r2.
kv.block_pool.touch(computed_r2.blocks)
for blk in pinned_prefix:
assert blk.ref_cnt >= 2
# Pinned blocks should be protected from eviction.
pool = kv.block_pool
for blk in pinned_prefix:
evicted = pool._maybe_evict_cached_block(blk)
assert evicted is False
assert blk.block_hash is not None
# Verify the block remains in the cached map
assert pool.cached_block_hash_to_block.get_one_block(blk.block_hash) is not None
def test_pinned_prefix_cap_and_return_fields():
"""Verify cap is enforced and return dict contains expected fields."""
block_size = 4
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=True)
req = _make_request("r", list(range(40)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert set(result.keys()) == {
"pinned",
"pinned_count",
"requested_count",
"cap_limit",
"status",
}
assert result["cap_limit"] == int(
kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio
)
assert result["pinned_count"] <= result["cap_limit"]
assert kv.block_pool.num_pinned_blocks == sum(
1 for b in kv.block_pool.blocks if b.is_pinned
)
assert result["status"] in {"ok", "partial", "capped"}
def test_pinned_prefix_statuses():
"""Cover disabled / ok / capped cases for status field."""
block_size = 4
# disabled: global gate off
kv = _make_manager(
blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=False
)
req = _make_request("r0", list(range(32)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["status"] == "disabled"
assert result["pinned"] is False
# ok: cap large enough, all requested pinned
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=1.0, enable_pin=True)
req = _make_request("r1", list(range(16)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["status"] == "ok"
assert result["pinned"] is True
assert result["pinned_count"] == result["requested_count"]
# capped: cap=0, requested>0 but pinned_count==0
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.0, enable_pin=True)
req = _make_request("r2", list(range(20)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["requested_count"] > 0
assert result["pinned_count"] == 0
assert result["status"] == "capped"
# -----------------------------------------------------------------------------
# Additional tests merged from test_pinned_prefix_caching.py
# -----------------------------------------------------------------------------
def create_request(
request_id: str, prompt_token_ids: list[int], pin_prefix: bool = False
) -> Request:
"""Helper function to create a request with optional prefix pinning."""
sampling_params = SamplingParams(max_tokens=10, pin_prefix=pin_prefix)
block_hasher = get_request_block_hasher(4, sha256_cbor)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
block_hasher=block_hasher,
)
class TestPinnedPrefixCaching:
"""Test cases for pinned prefix caching functionality (unit-level)."""
def test_sampling_params_pin_prefix_default(self):
"""Test that pin_prefix defaults to False in SamplingParams."""
params = SamplingParams()
assert params.pin_prefix is False
def test_sampling_params_pin_prefix_enabled(self):
"""Test that pin_prefix can be set to True in SamplingParams."""
params = SamplingParams(pin_prefix=True)
assert params.pin_prefix is True
def test_sampling_params_from_optional_pin_prefix(self):
"""Test that pin_prefix is correctly passed through from_optional."""
params = SamplingParams.from_optional(pin_prefix=True)
assert params.pin_prefix is True
def test_block_pool_pin_blocks(self):
"""Test that blocks can be pinned to prevent eviction."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get some blocks
blocks = block_pool.get_new_blocks(3)
# Pin the blocks
block_pool.pin_blocks(blocks)
# Verify blocks are pinned
for block in blocks:
assert block.is_pinned is True
assert block.ref_cnt >= 1
def test_block_pool_unpin_blocks(self):
"""Test that pinned blocks can be unpinned."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get and pin some blocks
blocks = block_pool.get_new_blocks(3)
block_pool.pin_blocks(blocks)
# Unpin the blocks
block_pool.unpin_blocks(blocks)
# Verify blocks are unpinned
for block in blocks:
assert block.is_pinned is False
def test_pinned_blocks_protected_from_eviction(self):
"""Test that pinned blocks are protected from eviction."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get some blocks and make them cached
blocks = block_pool.get_new_blocks(3)
# Simulate caching by setting block hash using the BlockPool API
for i, block in enumerate(blocks):
# Set a dummy hash to make it cached
dummy_hash = f"dummy_hash_{i}".encode()
# Compose a BlockHashWithGroupId and set via the property
from vllm.v1.core.kv_cache_utils import make_block_hash_with_group_id
bh = make_block_hash_with_group_id(dummy_hash, 0)
block.block_hash = bh
# Insert via public method using the same key
block_pool.cached_block_hash_to_block.insert(bh, block)
# Pin one of the blocks
block_pool.pin_blocks([blocks[0]])
# Try to evict all blocks
for block in blocks:
evicted = block_pool._maybe_evict_cached_block(block)
if block == blocks[0]:
# Pinned block should not be evicted
assert evicted is False
assert block.block_hash is not None # Still has hash
else:
# Non-pinned blocks should be evicted
assert evicted is True
assert block.block_hash is None # Hash removed
def test_cache_blocks_with_pin_prefix(self):
"""Test pin_prefix setting is correctly stored in SamplingParams."""
# Create a request with pin_prefix enabled
request = create_request("test_request", [1, 2, 3, 4, 5, 6], pin_prefix=True)
# Verify that pin_prefix is correctly set
assert request.sampling_params.pin_prefix is True
# Test calculating blocks to pin
block_size = 4
num_computed_tokens = 6
num_blocks_to_pin = num_computed_tokens // block_size  # 6 // 4 == 1 full block
assert num_blocks_to_pin == 1
def test_cache_blocks_with_multiple_full_blocks_pinned(self):
"""Test calculating multiple full blocks for pinning."""
from vllm.utils import sha256_cbor
from vllm.v1.core.kv_cache_utils import init_none_hash
# Initialize the hash function
init_none_hash(sha256_cbor)
# Create request with pin_prefix enabled and enough tokens for blocks
request = create_request("test_request", list(range(20)), pin_prefix=True)
# Verify that pin_prefix is correctly set
assert request.sampling_params.pin_prefix is True
# Test calculating blocks to pin with multiple full blocks
block_size = 4
num_computed_tokens = 16 # 4 full blocks
num_blocks_to_pin = num_computed_tokens // block_size # Should be 4
# Check that the calculation is correct
assert num_blocks_to_pin == 4
def test_cache_blocks_without_pin_prefix(self):
"""Test that pin_prefix defaults to False when not specified."""
from vllm.utils import sha256_cbor
from vllm.v1.core.kv_cache_utils import init_none_hash
# Initialize the hash function
init_none_hash(sha256_cbor)
# Create a request without pin_prefix
request = create_request("test_request", list(range(20)), pin_prefix=False)
# Verify that pin_prefix is correctly set to False
assert request.sampling_params.pin_prefix is False
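The numbers asserted in the tests above follow directly from the cap and prefix-depth formulas that appear in the KVCacheManager hunk later in this diff; a small worked check using the fixture sizes from these tests (block_size 4):

# Worked arithmetic for the fixtures above; formulas mirror the manager code below.
def cap_limit(num_gpu_blocks: int, cap_ratio: float) -> int:
    return int(num_gpu_blocks * cap_ratio)

def requested_blocks(num_prompt_tokens: int, block_size: int = 4) -> int:
    # The last prompt token is excluded, then rounded down to full blocks.
    return max(num_prompt_tokens - 1, 0) // block_size

assert cap_limit(20, 0.2) == 4    # multi-group test: 4 pins over 2 groups -> depth 2
assert cap_limit(11, 0.2) == 2    # cap-and-return-fields test
assert cap_limit(11, 0.0) == 0    # "capped" case in the statuses test
assert requested_blocks(40) == 9  # 39 computed tokens -> 9 full blocks requested
assert requested_blocks(12) == 2  # concurrency test: 2 pinned prefix blocks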

View File

@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
import uuid
import pytest
import torch
from transformers import AutoTokenizer
from ...utils import create_new_process_for_each_test
# Skip early if CUDA is unavailable to avoid importing heavy modules.
if not torch.cuda.is_available():
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
# Heavy imports (only after CUDA check)
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor
def _resolve_model_and_tokenizer():
"""Resolve a test model without hardcoded local paths.
Policy: prefer explicit env configuration; otherwise, fall back to a
realistic small model (Qwen/Qwen2-0.5B) when online. Avoid tiny models
that may exhibit scheduler timing quirks.
Offline/local-only mode (HF_HUB_OFFLINE=1 or VLLM_TEST_LOCAL_ONLY=1)
enforces local loading and skips if unavailable.
"""
local_only = (
bool(os.getenv("HF_HUB_OFFLINE")) or os.getenv("VLLM_TEST_LOCAL_ONLY") == "1"
)
# 1) Explicit model name or local path
env_model = os.getenv("VLLM_TEST_MODEL_NAME")
if env_model:
try:
tok = AutoTokenizer.from_pretrained(env_model, local_files_only=local_only)
return env_model, tok
except Exception as e: # pragma: no cover
last_err = e
pytest.skip(
reason=(
"VLLM_TEST_MODEL_NAME is set but cannot be loaded. "
f"Last error: {last_err}"
),
allow_module_level=True,
)
# 2) Explicit local model directory
env_local_dir = os.getenv("VLLM_TEST_LOCAL_MODEL_DIR")
if env_local_dir and os.path.isdir(env_local_dir):
try:
tok = AutoTokenizer.from_pretrained(env_local_dir, local_files_only=True)
return env_local_dir, tok
except Exception as e: # pragma: no cover
last_err = e
pytest.skip(
reason=(
"VLLM_TEST_LOCAL_MODEL_DIR is set but cannot be loaded. "
f"Last error: {last_err}"
),
allow_module_level=True,
)
# 3) Online fallback to Qwen 0.5B (no offline fallback)
if not local_only:
try:
name = "Qwen/Qwen2-0.5B"
tok = AutoTokenizer.from_pretrained(name, local_files_only=False)
return name, tok
except Exception as e: # pragma: no cover
last_err = e
# fall through to skip below
else:
last_err = RuntimeError("Offline mode and no local model available.")
pytest.skip(
reason=(
"No usable test model configured. Please set VLLM_TEST_MODEL_NAME "
"(HF model id or local path) or VLLM_TEST_LOCAL_MODEL_DIR to a "
"local model directory. Offline mode is respected via "
f"HF_HUB_OFFLINE/VLLM_TEST_LOCAL_ONLY. Last error: {last_err}"
),
allow_module_level=True,
)
MODEL_NAME, TOKENIZER = _resolve_model_and_tokenizer()
def _make_request(prompt_token_ids: list[int], pin_prefix: bool) -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=SamplingParams(max_tokens=10, pin_prefix=pin_prefix),
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)
@create_new_process_for_each_test()
def test_pinned_prefix_blocks_and_cache_hits(monkeypatch: pytest.MonkeyPatch):
"""
End-to-end test: drive EngineCore scheduling with pin_prefix enabled and
validate (1) pinned full prefix blocks and (2) cache-hit tokens for a
subsequent request with the same prompt.
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
# Configure small block_size to make assertions easy and deterministic.
# Keep eager mode to reduce startup overhead in tests.
engine_args = EngineArgs(
model=MODEL_NAME,
block_size=16,
enable_prefix_caching=True,
enable_pinned_prefix=True,
enforce_eager=True,
dtype="half", # match debug_vllm.py for compatibility
max_model_len=128,
gpu_memory_utilization=float(os.getenv("VLLM_TEST_GPU_UTIL", "0.85")),
# Keep batch small to reduce memory.
max_num_batched_tokens=128,
)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=False
)
# Sanity: global gate for pinned prefix
kv_mgr = engine_core.scheduler.kv_cache_manager
assert kv_mgr.enable_pinned_prefix is True
# Build a prompt with enough tokens to fill multiple blocks.
# We rely on tokenizer to compute the concrete token length.
text = "Hello world! " * 8 # heuristic, typically > 20 tokens
prompt_token_ids = TOKENIZER(text).input_ids
assert len(prompt_token_ids) >= 8
# First request: enable pin_prefix so its full prefix blocks get pinned
# during caching inside allocation.
req1 = _make_request(prompt_token_ids, pin_prefix=True)
engine_core.add_request(*engine_core.preprocess_add_request(req1))
# One step schedules prefill and commits cached full blocks.
_ = engine_core.step()
# Ensure pinning (idempotent) to guard against scheduler timing where
# allocation happens right after first execution. This mirrors the
# manager's early pin behavior and is a no-op if already pinned.
req1_live = engine_core.scheduler.requests[req1.request_id]
engine_core.scheduler.kv_cache_manager.cache_blocks(
req1_live, num_computed_tokens=len(prompt_token_ids) - 1
)
# We do not assert block-level is_pinned here because the scheduler
# may not have persisted per-request blocks yet for new requests in
# the first step. Instead, we validate via the next request's
# cache-hit accounting below.
block_size = vllm_config.cache_config.block_size
# Second request: same prompt, pin_prefix disabled.
req2 = _make_request(prompt_token_ids, pin_prefix=False)
# Preprocess to obtain the internal Request for direct cache-hit check.
req2_internal, wave = engine_core.preprocess_add_request(req2)
computed_blocks, num_hits = (
engine_core.scheduler.kv_cache_manager.get_computed_blocks(req2_internal)
)
# Verify the cache-hit token count via the manager API. This is robust to
# scheduler-timing differences and matches the (N-1) rule.
expected_cached_tokens = (
(len(prompt_token_ids) - 1) // block_size
) * block_size
assert num_hits == expected_cached_tokens
# Do not add/step the second request here to avoid scheduler timing
# dependencies; the cache-hit verification above is sufficient to
# validate pinned prefix effectiveness.
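The final assertion reduces to integer arithmetic; the snippet below makes the (N-1) rule concrete with a hypothetical prompt length (the real test derives the length from the tokenizer).

# (N-1) cache-hit rule used above, with a hypothetical prompt length of 33 tokens.
block_size = 16
num_prompt_tokens = 33
expected_cached_tokens = ((num_prompt_tokens - 1) // block_size) * block_size
assert expected_cached_tokens == 32  # last token excluded, rounded down to full blocks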

View File

@ -125,6 +125,17 @@ class CacheConfig:
gpu_memory_utilization. Note that kv_cache_memory_bytes
(when not-None) ignores gpu_memory_utilization"""
pinned_prefix_cap_ratio: float = Field(default=0.2, ge=0, le=1)
"""Maximum fraction of total GPU blocks that may be pinned for prefix
caching per engine instance (in [0.0, 1.0]). Defaults to 0.2 (20%). This
cap prevents pinned prefixes from occupying all cache blocks, improving
stability under load."""
enable_pinned_prefix: bool = False
"""Global gate for pinned-prefix behavior. If False, requests with
pin_prefix=True will not pin any blocks. Default is disabled to allow
conservative rollouts."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -538,6 +538,8 @@ class EngineArgs:
async_scheduling: bool = SchedulerConfig.async_scheduling
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
pinned_prefix_cap_ratio: float = CacheConfig.pinned_prefix_cap_ratio
enable_pinned_prefix: bool = CacheConfig.enable_pinned_prefix
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@ -879,6 +881,12 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
)
cache_group.add_argument(
"--pinned-prefix-cap-ratio", **cache_kwargs["pinned_prefix_cap_ratio"]
)
cache_group.add_argument(
"--enable-pinned-prefix", **cache_kwargs["enable_pinned_prefix"]
)
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1339,6 +1347,8 @@ class EngineArgs:
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
pinned_prefix_cap_ratio=self.pinned_prefix_cap_ratio,
enable_pinned_prefix=self.enable_pinned_prefix,
)
ray_runtime_env = None
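The two new arguments map one-to-one onto the CacheConfig fields shown earlier, and onto the --pinned-prefix-cap-ratio / --enable-pinned-prefix CLI flags registered above. A short programmatic sketch with example values (model name is a placeholder):

# Programmatic equivalent of the new CLI flags; values are examples.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="Qwen/Qwen2-0.5B",
    enable_prefix_caching=True,
    enable_pinned_prefix=True,      # -> CacheConfig.enable_pinned_prefix
    pinned_prefix_cap_ratio=0.3,    # -> CacheConfig.pinned_prefix_cap_ratio, in [0, 1]
)
# args.create_engine_config() carries both values into the cache config,
# as shown in the hunk above.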

View File

@ -123,6 +123,15 @@ class EngineClient(ABC):
"""Reset the prefix cache"""
...
@abstractmethod
async def unpin_all_pinned_prefixes(self) -> int:
"""Unpin all pinned KV blocks across the engine instance.
Returns:
int: Number of blocks unpinned.
"""
...
@abstractmethod
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine"""

View File

@ -994,6 +994,16 @@ if envs.VLLM_SERVER_DEV_MODE:
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200)
@router.post("/unpin_all_pinned_prefixes")
async def unpin_all_pinned_prefixes(raw_request: Request):
"""Unpin all pinned KV blocks across the engine instance.
Returns JSON with count of unpinned blocks.
"""
logger.info("Unpinning all pinned KV blocks ...")
count = await engine_client(raw_request).unpin_all_pinned_prefixes()
return JSONResponse(content={"unpinned": int(count)})
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params
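Because the route is registered under envs.VLLM_SERVER_DEV_MODE, it is only reachable on a server started in dev mode. A minimal client-side sketch (the server URL is a placeholder):

# Client-side sketch for the new dev-mode endpoint.
import requests

resp = requests.post("http://localhost:8000/unpin_all_pinned_prefixes")
resp.raise_for_status()
print(resp.json())  # e.g. {"unpinned": 3}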

View File

@ -529,6 +529,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs: int | None = None
allowed_token_ids: list[int] | None = None
bad_words: list[str] = Field(default_factory=list)
pin_prefix: bool = Field(
default=False,
description=(
"If true, the prefix of this request will be pinned in the cache, "
"preventing it from being evicted. Pinned prefixes are protected "
"from LRU eviction and will remain in cache even when memory is "
"under pressure."
),
)
# --8<-- [end:chat-completion-sampling-params]
# --8<-- [start:chat-completion-extra-params]
@ -868,6 +877,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias=self.logit_bias,
bad_words=self.bad_words,
allowed_token_ids=self.allowed_token_ids,
pin_prefix=self.pin_prefix,
extra_args=extra_args or None,
)
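Since pin_prefix is not part of the upstream OpenAI schema, an OpenAI-SDK client would pass it through extra_body. A sketch, assuming an OpenAI-compatible vLLM server started with pinned prefixes enabled (model name and URL are placeholders):

# Sketch: passing pin_prefix through the OpenAI-compatible chat endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="Qwen/Qwen2-0.5B",                          # placeholder
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"pin_prefix": True},                  # -> ChatCompletionRequest.pin_prefix
)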

View File

@ -220,6 +220,12 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
# Fields used for prefix caching
pin_prefix: bool = False
"""Whether to pin the prefix of this request in the cache, preventing it
from being evicted. Pinned prefixes will be prioritized and retained in
cache even when memory is under pressure."""
@staticmethod
def from_optional(
n: int | None = 1,
@ -252,6 +258,7 @@ class SamplingParams(
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
pin_prefix: bool = False,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
@ -303,6 +310,7 @@ class SamplingParams(
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
pin_prefix=pin_prefix,
)
def __post_init__(self) -> None:
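The field defaults to False and is forwarded unchanged by from_optional, which is exactly what the unit tests earlier in this diff assert; restated as a tiny self-check:

# pin_prefix defaults to False and flows through from_optional unchanged.
from vllm import SamplingParams

assert SamplingParams().pin_prefix is False
assert SamplingParams(pin_prefix=True).pin_prefix is True
assert SamplingParams.from_optional(pin_prefix=True).pin_prefix is True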

View File

@ -165,6 +165,8 @@ class BlockPool:
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
# Track total number of pinned blocks to avoid O(N) scans.
self.num_pinned_blocks: int = 0
def get_cached_block(
self, block_hash: BlockHash, kv_cache_group_ids: list[int]
@ -295,7 +297,8 @@ class BlockPool:
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
metadata and evict it from the cache. Pinned blocks are protected from
eviction.
Args:
block: The block to evict.
@ -303,6 +306,10 @@ class BlockPool:
Returns:
True if the block is evicted, False otherwise.
"""
# Check if the block is pinned and should not be evicted
if block.is_pinned:
return False
block_hash = block.block_hash
if block_hash is None:
# The block doesn't have hash, eviction is not needed
@ -424,3 +431,69 @@ class BlockPool:
events = self.kv_event_queue
self.kv_event_queue = []
return events
def pin_blocks(self, blocks: list[KVCacheBlock]) -> None:
"""Pin a list of blocks to prevent them from being evicted.
Pinned blocks will have their reference count increased to ensure
they remain in use and are not added to the free queue.
This operation is idempotent: pinning an already-pinned block has no
effect.
Args:
blocks: A list of blocks to pin.
"""
for block in blocks:
# Idempotency check: skip if already pinned.
if block.is_pinned:
continue
# Mark as pinned.
block.is_pinned = True
# If the block is currently on the free list (ref_cnt == 0),
# remove it from the free list before increasing the ref count.
if block.ref_cnt == 0 and (
block.prev_free_block is not None
or block.next_free_block is not None
or self.free_block_queue.fake_free_list_head.next_free_block == block
):
self.free_block_queue.remove(block)
# Increase ref count to reflect the pin reference (pin itself is
# an additional reference).
block.ref_cnt += 1
# Update global counter.
self.num_pinned_blocks += 1
def unpin_blocks(self, blocks: list[KVCacheBlock]) -> None:
"""Unpin a list of blocks, allowing them to be evicted.
This operation is idempotent: unpinning an already-unpinned block has
no effect.
Args:
blocks: A list of blocks to unpin.
"""
for block in blocks:
# Idempotency check: skip if already unpinned.
if not block.is_pinned:
continue
# Clear pinned flag.
block.is_pinned = False
# Drop the pin reference.
if block.ref_cnt > 0:
block.ref_cnt -= 1
# Update global counter.
if self.num_pinned_blocks > 0:
self.num_pinned_blocks -= 1
# If this drop brings ref_cnt to 0, add back to the free queue
# (unless it is the null block).
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append_n([block])
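The ref-count convention described in the docstrings above (pin adds one reference and removes the block from the free list; unpin drops the reference and requeues blocks that reach ref_cnt 0) can be checked with the same BlockPool API the unit tests use. A minimal sketch, assuming this change is applied:

# Ref-count / counter invariants for pin_blocks and unpin_blocks (sketch).
from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=8, enable_caching=True)
blocks = pool.get_new_blocks(2)     # ref_cnt == 1 on each, off the free queue
pool.pin_blocks(blocks)             # ref_cnt -> 2, is_pinned True, counter += 2
assert pool.num_pinned_blocks == 2
pool.pin_blocks(blocks)             # idempotent: already pinned, nothing changes
assert pool.num_pinned_blocks == 2
pool.unpin_blocks(blocks)           # ref_cnt -> 1 (still in use), counter -> 0
assert pool.num_pinned_blocks == 0
assert all(not b.is_pinned for b in blocks)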

View File

@ -100,6 +100,8 @@ class KVCacheManager:
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
pinned_prefix_cap_ratio: float = 0.2,
enable_pinned_prefix: bool = False,
) -> None:
self.max_model_len = max_model_len
@ -142,6 +144,8 @@ class KVCacheManager:
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
self.pinned_prefix_cap_ratio = pinned_prefix_cap_ratio
self.enable_pinned_prefix = enable_pinned_prefix
# Pre-constructed KVCacheBlocks with no blocks, callers should use this
# via create_kv_cache_blocks instead of creating new ones to avoid GC
@ -333,7 +337,19 @@ class KVCacheManager:
num_tokens_to_cache = min(
num_computed_tokens + num_new_tokens, request.num_tokens
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
# Cache and pin (prompt-only) early to protect blocks before execution.
pin_info = self.cache_blocks(request, num_tokens_to_cache)
# Optionally log pin details when stats logging is enabled.
if self.log_stats and isinstance(pin_info, dict):
status = pin_info.get("status", "disabled")
if status != "disabled":
logger.info(
"Prefix pin: status=%s pinned=%s requested=%s cap=%s",
status,
pin_info.get("pinned_count"),
pin_info.get("requested_count"),
pin_info.get("cap_limit"),
)
return self.create_kv_cache_blocks(new_blocks)
@ -413,13 +429,118 @@ class KVCacheManager:
"""Get the block ids of a request."""
return self.get_blocks(request_id).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
def cache_blocks(self, request: Request, num_computed_tokens: int) -> dict:
"""Cache the blocks for the request, and handle prefix pinning.
Returns a dict describing pinning results for observability:
- pinned: bool, whether any prefix blocks are pinned after this call
- pinned_count: int, logical pinned depth of the prefix (min across KV cache groups)
- requested_count: int, number of blocks requested to pin
- cap_limit: int, maximum allowed pinned blocks under current cap
- status: str, one of {"disabled", "ok", "partial", "capped"}
"""
# Always perform caching if enabled.
if self.enable_caching:
self.coordinator.cache_blocks(request, num_computed_tokens)
# Default result when pinning is disabled globally or per-request.
result = {
"pinned": False,
"pinned_count": 0,
"requested_count": 0,
"cap_limit": int(
self.block_pool.num_gpu_blocks * self.pinned_prefix_cap_ratio
)
if self.block_size is not None
else 0,
"status": "disabled",
}
# Check pinning gates.
if not (self.enable_caching and self.enable_pinned_prefix):
return result
if request.sampling_params is None or not getattr(
request.sampling_params, "pin_prefix", False
):
return result
if self.block_size is None:
return result
# Consider prompt tokens only: prefix caching excludes last-token logits.
prompt_tokens = max(request.num_prompt_tokens - 1, 0)
effective_tokens = min(max(num_computed_tokens, 0), prompt_tokens)
# Determine how many full blocks to pin from the computed prompt prefix.
requested_blocks = effective_tokens // self.block_size
result["requested_count"] = requested_blocks
if requested_blocks == 0:
result["status"] = "ok"
return result
# Enforce global cap on total pinned blocks across the pool.
cap_limit = int(self.block_pool.num_gpu_blocks * self.pinned_prefix_cap_ratio)
pinned_current = self.block_pool.num_pinned_blocks
budget = max(cap_limit - pinned_current, 0)
result["cap_limit"] = cap_limit
# Round-robin pin by logical prefix depth across groups to respect
# the global budget. This ensures we never exceed the cap even when
# there are multiple KV cache groups.
if budget > 0 and requested_blocks > 0:
blocks = self.coordinator.get_blocks(request.request_id)
active_groups = [g for g in blocks if g]
remaining = budget
depth = 0
while remaining > 0 and depth < requested_blocks:
for group_blocks in active_groups:
if remaining == 0:
break
if depth < len(group_blocks):
blk = group_blocks[depth]
if not blk.is_pinned:
self.block_pool.pin_blocks([blk])
remaining -= 1
depth += 1
# Compute logical pinned depth across all active groups
blocks = self.coordinator.get_blocks(request.request_id)
active_groups = [g for g in blocks if g]
if active_groups:
pinned_per_group = [
sum(1 for b in g[:requested_blocks] if b.is_pinned)
for g in active_groups
]
logical_pinned = min(pinned_per_group)
else:
logical_pinned = 0
result["pinned"] = logical_pinned > 0
result["pinned_count"] = logical_pinned
if logical_pinned == requested_blocks:
result["status"] = "ok"
elif logical_pinned == 0:
# If budget was zero, we were hard-capped; otherwise partial.
result["status"] = "capped" if budget == 0 else "partial"
else:
result["status"] = "partial"
return result
def create_kv_cache_blocks(
self, blocks: tuple[list[KVCacheBlock], ...]
) -> KVCacheBlocks:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
# (Removed per-request unpin to keep surface minimal)
def unpin_all_pinned_prefixes(self) -> int:
"""Unpin all pinned KV blocks across the pool.
Returns:
int: Number of blocks unpinned.
"""
pinned = [b for b in self.block_pool.blocks if b.is_pinned]
if pinned:
self.block_pool.unpin_blocks(pinned)
return len(pinned)
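The dict returned by cache_blocks() is meant for observability. A hypothetical helper (not part of this change) showing how a caller might turn it into a log line, using the field names documented above:

# Hypothetical helper for summarizing the cache_blocks() result; not part of the diff.
def describe_pin_result(info: dict) -> str:
    status = info["status"]
    if status == "disabled":
        return "pinning disabled (global gate off, pin_prefix unset, or caching off)"
    if status == "ok":
        return f"pinned full prefix: {info['pinned_count']} block(s)"
    if status == "capped":
        return (f"cap exhausted: 0 of {info['requested_count']} block(s) pinned "
                f"(cap={info['cap_limit']})")
    return (f"partially pinned: {info['pinned_count']} of {info['requested_count']} "
            f"block(s) (cap={info['cap_limit']})")

print(describe_pin_result({"pinned": True, "pinned_count": 2,
                           "requested_count": 4, "cap_limit": 4,
                           "status": "partial"}))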

View File

@ -119,6 +119,10 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False
# Whether the block is pinned and should not be evicted from cache.
# Pinned blocks are protected from LRU eviction and will remain in cache
# until manually unpinned or freed.
is_pinned: bool = False
@property
def block_hash(self) -> BlockHashWithGroupId | None:

View File

@ -135,6 +135,15 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def unpin_all_pinned_prefixes(self) -> int:
"""Unpin all pinned KV blocks across all requests.
Returns:
int: Number of blocks unpinned.
"""
raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""

View File

@ -170,6 +170,8 @@ class Scheduler(SchedulerInterface):
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pinned_prefix_cap_ratio=self.cache_config.pinned_prefix_cap_ratio,
enable_pinned_prefix=self.cache_config.enable_pinned_prefix,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
@ -1237,6 +1239,14 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def unpin_all_pinned_prefixes(self) -> int:
"""Unpin all pinned KV blocks across all requests.
Returns:
int: Number of blocks unpinned.
"""
return self.kv_cache_manager.unpin_all_pinned_prefixes()
def make_stats(
self,
spec_decoding_stats: SpecDecodingStats | None = None,

View File

@ -676,6 +676,9 @@ class AsyncLLM(EngineClient):
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async()
async def unpin_all_pinned_prefixes(self) -> int:
return await self.engine_core.unpin_all_pinned_prefixes_async()
async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()
await self.engine_core.sleep_async(level)

View File

@ -420,6 +420,9 @@ class EngineCore:
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
def unpin_all_pinned_prefixes(self) -> int:
return self.scheduler.unpin_all_pinned_prefixes()
def sleep(self, level: int = 1):
self.model_executor.sleep(level)

View File

@ -141,6 +141,9 @@ class EngineCoreClient(ABC):
def reset_prefix_cache(self) -> None:
raise NotImplementedError
def unpin_all_pinned_prefixes(self) -> int:
raise NotImplementedError
def sleep(self, level: int = 1) -> None:
raise NotImplementedError
@ -211,6 +214,9 @@ class EngineCoreClient(ABC):
async def reset_prefix_cache_async(self) -> None:
raise NotImplementedError
async def unpin_all_pinned_prefixes_async(self) -> int:
raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError
@ -290,6 +296,9 @@ class InprocClient(EngineCoreClient):
def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache()
def unpin_all_pinned_prefixes(self) -> int:
return self.engine_core.unpin_all_pinned_prefixes()
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
@ -753,6 +762,9 @@ class SyncMPClient(MPClient):
def reset_prefix_cache(self) -> None:
self.call_utility("reset_prefix_cache")
def unpin_all_pinned_prefixes(self) -> int:
return self.call_utility("unpin_all_pinned_prefixes")
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request)
@ -957,6 +969,9 @@ class AsyncMPClient(MPClient):
async def reset_prefix_cache_async(self) -> None:
await self.call_utility_async("reset_prefix_cache")
async def unpin_all_pinned_prefixes_async(self) -> int:
return await self.call_utility_async("unpin_all_pinned_prefixes")
async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level)

View File

@ -331,6 +331,9 @@ class LLMEngine:
def reset_prefix_cache(self, device: Device | None = None):
self.engine_core.reset_prefix_cache()
def unpin_all_pinned_prefixes(self) -> int:
return self.engine_core.unpin_all_pinned_prefixes()
def sleep(self, level: int = 1):
self.engine_core.sleep(level)