vllm/tests/v1/core/test_pinned_prefix.py
dongbo910220 b0cde8866e feat(v1): Implement pinned prefix caching with global unpin API
Core Features:
- Add pin_prefix parameter to SamplingParams for per-request prefix pinning (usage sketched after this list)
- Implement pinned prefix caching in V1 engine KVCacheManager
- Add pinned_prefix_cap_ratio (default 0.2) to control memory usage
- Add enable_pinned_prefix global gate for conservative rollouts
- Protect pinned blocks from LRU eviction in BlockPool
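
A minimal usage sketch of the per-request flag and the cap budget (the pool
size here is illustrative; the parameter names follow this commit):

    from vllm.sampling_params import SamplingParams

    # Per-request opt-in: keep this prompt's full prefix blocks resident.
    params = SamplingParams(max_tokens=64, pin_prefix=True)

    # The global budget is a fraction of the GPU block pool. With the default
    # cap ratio of 0.2 and an assumed pool of 20 blocks, at most 4 blocks may
    # stay pinned at once.
    cap_limit = int(20 * 0.2)
    assert cap_limit == 4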

Bug Fixes:
- Fix multi-group budget bug with round-robin pinning strategy (see the sketch after this list)
- Ensure global cap is never exceeded even with multiple KV cache groups
- Use logical pinned depth (min across groups) for accurate reporting
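
A schematic sketch of the round-robin strategy (a hypothetical helper, not
the actual KVCacheManager code; the block type is a stand-in):

    from dataclasses import dataclass

    @dataclass
    class _Block:
        is_pinned: bool = False  # stand-in for a KV cache block

    def round_robin_pin(groups, requested_depth, cap):
        """Pin depth d in every group before depth d + 1; never exceed cap."""
        pinned_total = 0
        logical_depth = 0
        for depth in range(requested_depth):
            if pinned_total + len(groups) > cap:
                break  # finishing this depth in all groups would exceed the cap
            for blocks in groups:
                blocks[depth].is_pinned = True
                pinned_total += 1
            logical_depth += 1  # a depth counts only once every group has it
        return logical_depth

    # Two groups, requested depth 3, cap 4: depths 0 and 1 are pinned in both
    # groups (4 blocks) and the reported logical depth is 2, i.e. "partial".
    groups = [[_Block() for _ in range(3)] for _ in range(2)]
    assert round_robin_pin(groups, requested_depth=3, cap=4) == 2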

Management APIs:
- Add HTTP endpoint POST /unpin_all_pinned_prefixes for memory reclamation (client call sketched below)
- Implement complete call chain: API -> AsyncLLM -> EngineCore -> Scheduler -> KVCacheManager
- Remove per-request unpin to keep API surface minimal
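
A minimal client-side sketch for the endpoint (the endpoint path comes from
this commit; the base URL and the use of `requests` are assumptions):

    import requests

    resp = requests.post("http://localhost:8000/unpin_all_pinned_prefixes")
    resp.raise_for_status()  # response body/shape is server-defined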

Code Quality:
- Replace manual @field_validator with Field(ge=0, le=1) for cleaner validation (sketched below)
- Add comprehensive test coverage (unit + integration + E2E)
- Add test_multi_group_prefix_pinning_respects_global_cap() for multi-group validation
- Add test_unpin_all_pinned_prefixes_clears_pool() for unpin API validation
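
A sketch of the declarative validation pattern named above, assuming a
pydantic-style model (the surrounding class here is illustrative):

    from pydantic import BaseModel, Field

    class CacheConfigSketch(BaseModel):
        # Values outside [0, 1] are rejected with no manual @field_validator.
        pinned_prefix_cap_ratio: float = Field(default=0.2, ge=0, le=1)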

Resolves: #23083
Signed-off-by: dongbo910220 <1275604947@qq.com>
2025-10-17 19:38:21 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pinned prefix caching: concurrency, cap, and status behaviors."""
import pytest # noqa: F401
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request


def _make_manager(
blocks: int,
block_size: int,
cap_ratio: float = 0.2,
enable_pin: bool = True,
num_groups: int = 1,
) -> KVCacheManager:
import torch
groups = []
for gi in range(num_groups):
groups.append(
KVCacheGroupSpec(
layer_names=[f"layer_{gi}"],
kv_cache_spec=FullAttentionSpec(
block_size=block_size,
num_kv_heads=2,
head_size=8,
dtype=torch.float16,
),
)
)
cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=blocks)
init_none_hash(sha256_cbor)
return KVCacheManager(
cfg,
max_model_len=1024,
enable_caching=True,
pinned_prefix_cap_ratio=cap_ratio,
enable_pinned_prefix=enable_pin,
)


def _make_request(
req_id: str, tokens: list[int], block_size: int, pin: bool
) -> Request:
sp = SamplingParams(max_tokens=4, pin_prefix=pin)
hasher = get_request_block_hasher(block_size, sha256_cbor)
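    # The two positional Nones are pooling_params and eos_token_id.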
return Request(req_id, tokens, sp, None, None, block_hasher=hasher)


def test_multi_group_prefix_pinning_respects_global_cap():
"""Multi-group pinning must not exceed global budget.
Create 2 groups, requested logical depth=3, but cap only allows ~4 pins.
Round-robin should pin depth 0 and 1 across both groups (total 4), and
report logical pinned depth=2 (partial).
"""
import torch
block_size = 4
# Build a hybrid config: one full-attn group + one sliding-window group
groups = [
KVCacheGroupSpec(
layer_names=["layer_fa"],
kv_cache_spec=FullAttentionSpec(
block_size=block_size, num_kv_heads=2, head_size=8, dtype=torch.float16
),
),
KVCacheGroupSpec(
layer_names=["layer_sw"],
kv_cache_spec=SlidingWindowSpec(
block_size=block_size,
num_kv_heads=2,
head_size=8,
dtype=torch.float16,
sliding_window=8,
),
),
]
cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=20)
init_none_hash(sha256_cbor)
kv = KVCacheManager(
cfg,
max_model_len=1024,
enable_caching=True,
pinned_prefix_cap_ratio=0.2,
enable_pinned_prefix=True,
)
req = _make_request("mg", list(range(20)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
num_computed = len(req.all_token_ids) - 1 # exclude last token for logits
result = kv.cache_blocks(req, num_computed_tokens=num_computed)
cap_limit = int(kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio)
assert result["cap_limit"] == cap_limit
# Check BlockPool global counter does not exceed cap
assert kv.block_pool.num_pinned_blocks <= cap_limit
    # With 2 groups and cap == 4, expect logical pinned depth == 2 (partial)
assert result["pinned_count"] <= result["requested_count"]
assert result["status"] in {"ok", "partial", "capped"}
assert result["status"] == "partial"
# Ensure each group's first two blocks are pinned
blocks = kv.coordinator.get_blocks(req.request_id)
for group_blocks in blocks:
if not group_blocks:
continue
assert all(b.is_pinned for b in group_blocks[:2])
# (Per-request unpin method removed to keep surface minimal.)


def test_unpin_all_pinned_prefixes_clears_pool():
"""Global unpin clears all pinned blocks regardless of request id."""
block_size = 4
kv = _make_manager(
blocks=24, block_size=block_size, cap_ratio=0.5, enable_pin=True, num_groups=1
)
req = _make_request("unp_all", list(range(12)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert kv.block_pool.num_pinned_blocks > 0
unpinned = kv.unpin_all_pinned_prefixes()
assert unpinned >= 1
assert kv.block_pool.num_pinned_blocks == 0


def test_concurrent_prefix_sharing_and_pinned_eviction_protection():
"""Two requests share pinned prefix; evictions avoided for pins."""
block_size = 4
kv = _make_manager(blocks=24, block_size=block_size)
# Prompt spans 3 full blocks (12 tokens).
prompt = list(range(12))
# r1: enable pin_prefix so its full-prefix blocks get pinned.
r1 = _make_request("r1", prompt, block_size, pin=True)
computed_r1, hits_r1 = kv.get_computed_blocks(r1)
assert hits_r1 == 0
assert all(len(g) == 0 for g in computed_r1.blocks)
kv.allocate_slots(r1, num_new_tokens=len(prompt))
kv.cache_blocks(r1, num_computed_tokens=len(prompt) - 1)
num_pinned_blocks = (len(prompt) - 1) // block_size
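    # (12 - 1) // 4 == 2: only full blocks of the computed prefix are pinned.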
r1_blocks = kv.coordinator.get_blocks(r1.request_id)[0]
assert len(r1_blocks) >= num_pinned_blocks
pinned_prefix = r1_blocks[:num_pinned_blocks]
for blk in pinned_prefix:
assert blk.is_pinned is True
# r2: same prompt; should share the cached prefix blocks.
r2 = _make_request("r2", prompt, block_size, pin=False)
computed_r2, hits_r2 = kv.get_computed_blocks(r2)
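    # r1 cached two full blocks (8 tokens), so r2's prefix hit is 8 tokens.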
assert hits_r2 == num_pinned_blocks * block_size
assert computed_r2.blocks[0] == pinned_prefix
# Simulate scheduler touching for r2.
kv.block_pool.touch(computed_r2.blocks)
for blk in pinned_prefix:
assert blk.ref_cnt >= 2
# Pinned blocks should be protected from eviction.
pool = kv.block_pool
for blk in pinned_prefix:
evicted = pool._maybe_evict_cached_block(blk)
assert evicted is False
assert blk.block_hash is not None
# Verify the block remains in the cached map
assert pool.cached_block_hash_to_block.get_one_block(blk.block_hash) is not None


def test_pinned_prefix_cap_and_return_fields():
"""Verify cap is enforced and return dict contains expected fields."""
block_size = 4
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=True)
req = _make_request("r", list(range(40)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert set(result.keys()) == {
"pinned",
"pinned_count",
"requested_count",
"cap_limit",
"status",
}
assert result["cap_limit"] == int(
kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio
)
assert result["pinned_count"] <= result["cap_limit"]
assert kv.block_pool.num_pinned_blocks == sum(
1 for b in kv.block_pool.blocks if b.is_pinned
)
assert result["status"] in {"ok", "partial", "capped"}


def test_pinned_prefix_statuses():
"""Cover disabled / ok / capped cases for status field."""
block_size = 4
# disabled: global gate off
kv = _make_manager(
blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=False
)
req = _make_request("r0", list(range(32)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["status"] == "disabled"
assert result["pinned"] is False
# ok: cap large enough, all requested pinned
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=1.0, enable_pin=True)
req = _make_request("r1", list(range(16)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["status"] == "ok"
assert result["pinned"] is True
assert result["pinned_count"] == result["requested_count"]
# capped: cap=0, requested>0 but pinned_count==0
kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.0, enable_pin=True)
req = _make_request("r2", list(range(20)), block_size, pin=True)
kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
assert result["requested_count"] > 0
assert result["pinned_count"] == 0
assert result["status"] == "capped"


# -----------------------------------------------------------------------------
# Additional tests merged from test_pinned_prefix_caching.py
# -----------------------------------------------------------------------------


def create_request(
request_id: str, prompt_token_ids: list[int], pin_prefix: bool = False
) -> Request:
"""Helper function to create a request with optional prefix pinning."""
sampling_params = SamplingParams(max_tokens=10, pin_prefix=pin_prefix)
block_hasher = get_request_block_hasher(4, sha256_cbor)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
block_hasher=block_hasher,
)


class TestPinnedPrefixCaching:
"""Test cases for pinned prefix caching functionality (unit-level)."""

    def test_sampling_params_pin_prefix_default(self):
"""Test that pin_prefix defaults to False in SamplingParams."""
params = SamplingParams()
assert params.pin_prefix is False

    def test_sampling_params_pin_prefix_enabled(self):
"""Test that pin_prefix can be set to True in SamplingParams."""
params = SamplingParams(pin_prefix=True)
assert params.pin_prefix is True

    def test_sampling_params_from_optional_pin_prefix(self):
"""Test that pin_prefix is correctly passed through from_optional."""
params = SamplingParams.from_optional(pin_prefix=True)
assert params.pin_prefix is True

    def test_block_pool_pin_blocks(self):
"""Test that blocks can be pinned to prevent eviction."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get some blocks
blocks = block_pool.get_new_blocks(3)
# Pin the blocks
block_pool.pin_blocks(blocks)
# Verify blocks are pinned
for block in blocks:
assert block.is_pinned is True
assert block.ref_cnt >= 1

    def test_block_pool_unpin_blocks(self):
"""Test that pinned blocks can be unpinned."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get and pin some blocks
blocks = block_pool.get_new_blocks(3)
block_pool.pin_blocks(blocks)
# Unpin the blocks
block_pool.unpin_blocks(blocks)
# Verify blocks are unpinned
for block in blocks:
assert block.is_pinned is False

    def test_pinned_blocks_protected_from_eviction(self):
"""Test that pinned blocks are protected from eviction."""
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Get some blocks and make them cached
blocks = block_pool.get_new_blocks(3)
        # Simulate caching by assigning block hashes through the BlockPool API.
        from vllm.v1.core.kv_cache_utils import make_block_hash_with_group_id

        for i, block in enumerate(blocks):
            # A dummy hash is enough to make the block count as cached.
            dummy_hash = f"dummy_hash_{i}".encode()
            bh = make_block_hash_with_group_id(dummy_hash, 0)
            block.block_hash = bh
            # Register the block under the same key via the public mapping.
            block_pool.cached_block_hash_to_block.insert(bh, block)
# Pin one of the blocks
block_pool.pin_blocks([blocks[0]])
# Try to evict all blocks
for block in blocks:
evicted = block_pool._maybe_evict_cached_block(block)
if block == blocks[0]:
# Pinned block should not be evicted
assert evicted is False
assert block.block_hash is not None # Still has hash
else:
# Non-pinned blocks should be evicted
assert evicted is True
assert block.block_hash is None # Hash removed

    def test_cache_blocks_with_pin_prefix(self):
"""Test pin_prefix setting is correctly stored in SamplingParams."""
# Create a request with pin_prefix enabled
request = create_request("test_request", [1, 2, 3, 4, 5, 6], pin_prefix=True)
# Verify that pin_prefix is correctly set
assert request.sampling_params.pin_prefix is True
# Test calculating blocks to pin
block_size = 4
num_computed_tokens = 6
        num_blocks_to_pin = num_computed_tokens // block_size  # 6 // 4 == 1 full block
assert num_blocks_to_pin == 1

    def test_cache_blocks_with_multiple_full_blocks_pinned(self):
"""Test calculating multiple full blocks for pinning."""
        # Initialize the NONE hash (sha256_cbor is imported at module level).
        init_none_hash(sha256_cbor)
# Create request with pin_prefix enabled and enough tokens for blocks
request = create_request("test_request", list(range(20)), pin_prefix=True)
# Verify that pin_prefix is correctly set
assert request.sampling_params.pin_prefix is True
# Test calculating blocks to pin with multiple full blocks
block_size = 4
num_computed_tokens = 16 # 4 full blocks
num_blocks_to_pin = num_computed_tokens // block_size # Should be 4
# Check that the calculation is correct
assert num_blocks_to_pin == 4

    def test_cache_blocks_without_pin_prefix(self):
"""Test that pin_prefix defaults to False when not specified."""
        # Initialize the NONE hash (sha256_cbor is imported at module level).
        init_none_hash(sha256_cbor)
# Create a request without pin_prefix
request = create_request("test_request", list(range(20)), pin_prefix=False)
# Verify that pin_prefix is correctly set to False
assert request.sampling_params.pin_prefix is False