Core Features:
- Add pin_prefix parameter to SamplingParams for per-request prefix pinning
- Implement pinned prefix caching in the V1 engine KVCacheManager
- Add pinned_prefix_cap_ratio (default 0.2) to control memory usage
- Add enable_pinned_prefix global gate for conservative rollouts
- Protect pinned blocks from LRU eviction in BlockPool

Bug Fixes:
- Fix multi-group budget bug with a round-robin pinning strategy
- Ensure the global cap is never exceeded, even with multiple KV cache groups
- Use logical pinned depth (min across groups) for accurate reporting

Management APIs:
- Add HTTP endpoint POST /unpin_all_pinned_prefixes for memory reclamation
- Implement the complete call chain: API -> AsyncLLM -> EngineCore -> Scheduler -> KVCacheManager
- Remove per-request unpin to keep the API surface minimal

Code Quality:
- Replace manual @field_validator with Field(ge=0, le=1) for cleaner validation
- Add comprehensive test coverage (unit + integration + E2E)
- Add test_multi_group_prefix_pinning_respects_global_cap() for multi-group validation
- Add test_unpin_all_pinned_prefixes_clears_pool() for unpin API validation

Resolves: #23083

Signed-off-by: dongbo910220 <1275604947@qq.com>
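A minimal client-side sketch of the feature described above. The SamplingParams usage matches the tests in this file; the model name, server URL, and use of the `requests` library are illustrative assumptions, and the endpoint name comes from the commit message.

    import requests

    from vllm import LLM
    from vllm.sampling_params import SamplingParams

    llm = LLM(model="facebook/opt-125m")  # hypothetical model choice

    # Pin this request's prompt prefix so its KV blocks survive LRU eviction.
    params = SamplingParams(max_tokens=32, pin_prefix=True)
    outputs = llm.generate(["You are a helpful assistant. ..."], params)

    # Later, reclaim the memory held by all pinned prefixes via the management
    # API (applies when running the HTTP server, not the in-process LLM above).
    requests.post("http://localhost:8000/unpin_all_pinned_prefixes")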
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pinned prefix caching: concurrency, cap, and status behaviors."""

import pytest  # noqa: F401

from vllm.sampling_params import SamplingParams
from vllm.utils import sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    SlidingWindowSpec,
)
from vllm.v1.request import Request


def _make_manager(
    blocks: int,
    block_size: int,
    cap_ratio: float = 0.2,
    enable_pin: bool = True,
    num_groups: int = 1,
) -> KVCacheManager:
    import torch

    groups = []
    for gi in range(num_groups):
        groups.append(
            KVCacheGroupSpec(
                layer_names=[f"layer_{gi}"],
                kv_cache_spec=FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=2,
                    head_size=8,
                    dtype=torch.float16,
                ),
            )
        )
    cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=blocks)
    init_none_hash(sha256_cbor)
    return KVCacheManager(
        cfg,
        max_model_len=1024,
        enable_caching=True,
        pinned_prefix_cap_ratio=cap_ratio,
        enable_pinned_prefix=enable_pin,
    )


def _make_request(
    req_id: str, tokens: list[int], block_size: int, pin: bool
) -> Request:
    sp = SamplingParams(max_tokens=4, pin_prefix=pin)
    hasher = get_request_block_hasher(block_size, sha256_cbor)
    return Request(req_id, tokens, sp, None, None, block_hasher=hasher)


def test_multi_group_prefix_pinning_respects_global_cap():
    """Multi-group pinning must not exceed the global budget.

    Create 2 groups with a requested logical depth of 3, while the cap only
    allows ~4 pinned blocks in total. Round-robin should pin depths 0 and 1
    across both groups (4 blocks total) and report a logical pinned depth of
    2 (partial).
    """
    import torch

    block_size = 4
    # Build a hybrid config: one full-attention group + one sliding-window group.
    groups = [
        KVCacheGroupSpec(
            layer_names=["layer_fa"],
            kv_cache_spec=FullAttentionSpec(
                block_size=block_size, num_kv_heads=2, head_size=8, dtype=torch.float16
            ),
        ),
        KVCacheGroupSpec(
            layer_names=["layer_sw"],
            kv_cache_spec=SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=2,
                head_size=8,
                dtype=torch.float16,
                sliding_window=8,
            ),
        ),
    ]
    cfg = KVCacheConfig(kv_cache_groups=groups, kv_cache_tensors=[], num_blocks=20)
    init_none_hash(sha256_cbor)
    kv = KVCacheManager(
        cfg,
        max_model_len=1024,
        enable_caching=True,
        pinned_prefix_cap_ratio=0.2,
        enable_pinned_prefix=True,
    )
    req = _make_request("mg", list(range(20)), block_size, pin=True)

    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    num_computed = len(req.all_token_ids) - 1  # exclude last token for logits
    result = kv.cache_blocks(req, num_computed_tokens=num_computed)

    cap_limit = int(kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio)
    assert result["cap_limit"] == cap_limit

    # The BlockPool-wide counter must stay within the cap.
    assert kv.block_pool.num_pinned_blocks <= cap_limit

    # With 2 groups and a cap of ~4 blocks, expect logical pinned depth == 2.
    assert result["pinned_count"] <= result["requested_count"]
    assert result["status"] in {"ok", "partial", "capped"}
    assert result["status"] == "partial"

    # Ensure each group's first two blocks are pinned.
    blocks = kv.coordinator.get_blocks(req.request_id)
    for group_blocks in blocks:
        if not group_blocks:
            continue
        assert all(b.is_pinned for b in group_blocks[:2])


# (The per-request unpin method was removed to keep the API surface minimal;
# only the global unpin below is exposed.)


def test_unpin_all_pinned_prefixes_clears_pool():
    """Global unpin clears all pinned blocks regardless of request id."""
    block_size = 4
    kv = _make_manager(
        blocks=24, block_size=block_size, cap_ratio=0.5, enable_pin=True, num_groups=1
    )
    req = _make_request("unp_all", list(range(12)), block_size, pin=True)
    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)

    assert kv.block_pool.num_pinned_blocks > 0
    unpinned = kv.unpin_all_pinned_prefixes()
    assert unpinned >= 1
    assert kv.block_pool.num_pinned_blocks == 0


def test_concurrent_prefix_sharing_and_pinned_eviction_protection():
    """Two requests share a pinned prefix; pinned blocks resist eviction."""
    block_size = 4
    kv = _make_manager(blocks=24, block_size=block_size)

    # Prompt spans 3 full blocks (12 tokens).
    prompt = list(range(12))

    # r1: enable pin_prefix so its full-prefix blocks get pinned.
    r1 = _make_request("r1", prompt, block_size, pin=True)
    computed_r1, hits_r1 = kv.get_computed_blocks(r1)
    assert hits_r1 == 0
    assert all(len(g) == 0 for g in computed_r1.blocks)

    kv.allocate_slots(r1, num_new_tokens=len(prompt))
    kv.cache_blocks(r1, num_computed_tokens=len(prompt) - 1)

    # 11 computed tokens -> 2 full blocks of 4; only full blocks are pinned.
    num_pinned_blocks = (len(prompt) - 1) // block_size
    r1_blocks = kv.coordinator.get_blocks(r1.request_id)[0]
    assert len(r1_blocks) >= num_pinned_blocks
    pinned_prefix = r1_blocks[:num_pinned_blocks]
    for blk in pinned_prefix:
        assert blk.is_pinned is True

    # r2: same prompt; should share the cached prefix blocks.
    r2 = _make_request("r2", prompt, block_size, pin=False)
    computed_r2, hits_r2 = kv.get_computed_blocks(r2)
    assert hits_r2 == num_pinned_blocks * block_size
    assert computed_r2.blocks[0] == pinned_prefix

    # Simulate the scheduler touching the shared blocks for r2.
    kv.block_pool.touch(computed_r2.blocks)
    for blk in pinned_prefix:
        assert blk.ref_cnt >= 2

    # Pinned blocks must be protected from eviction.
    pool = kv.block_pool
    for blk in pinned_prefix:
        evicted = pool._maybe_evict_cached_block(blk)
        assert evicted is False
        assert blk.block_hash is not None
        # The block must also remain in the cached-block map.
        assert pool.cached_block_hash_to_block.get_one_block(blk.block_hash) is not None
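

# For reference, the protection exercised above is expected to amount to an
# early-out in BlockPool._maybe_evict_cached_block. A minimal sketch of the
# assumed shape (not the actual vLLM source):
#
#     def _maybe_evict_cached_block(self, block) -> bool:
#         if block.is_pinned:
#             return False  # pinned prefixes are never evicted
#         ...  # otherwise drop the hash and remove from the cached-block map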


def test_pinned_prefix_cap_and_return_fields():
    """Verify the cap is enforced and the return dict has the expected fields."""
    block_size = 4
    kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=True)
    req = _make_request("r", list(range(40)), block_size, pin=True)

    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)

    assert set(result.keys()) == {
        "pinned",
        "pinned_count",
        "requested_count",
        "cap_limit",
        "status",
    }
    assert result["cap_limit"] == int(
        kv.block_pool.num_gpu_blocks * kv.pinned_prefix_cap_ratio
    )
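    # Worked numbers, assuming num_gpu_blocks == 11 here: cap_limit ==
    # int(11 * 0.2) == 2, while the request's 39 computed tokens span
    # 9 full blocks, so pinning must stop well short of the request.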
    assert result["pinned_count"] <= result["cap_limit"]
    assert kv.block_pool.num_pinned_blocks == sum(
        1 for b in kv.block_pool.blocks if b.is_pinned
    )
    assert result["status"] in {"ok", "partial", "capped"}


def test_pinned_prefix_statuses():
    """Cover the disabled / ok / capped cases for the status field."""
    block_size = 4

    # disabled: global gate off.
    kv = _make_manager(
        blocks=11, block_size=block_size, cap_ratio=0.2, enable_pin=False
    )
    req = _make_request("r0", list(range(32)), block_size, pin=True)
    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
    assert result["status"] == "disabled"
    assert result["pinned"] is False

    # ok: cap large enough, everything requested gets pinned.
    kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=1.0, enable_pin=True)
    req = _make_request("r1", list(range(16)), block_size, pin=True)
    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
    assert result["status"] == "ok"
    assert result["pinned"] is True
    assert result["pinned_count"] == result["requested_count"]

    # capped: cap=0, so requested_count > 0 but pinned_count == 0.
    kv = _make_manager(blocks=11, block_size=block_size, cap_ratio=0.0, enable_pin=True)
    req = _make_request("r2", list(range(20)), block_size, pin=True)
    kv.allocate_slots(req, num_new_tokens=len(req.all_token_ids))
    result = kv.cache_blocks(req, num_computed_tokens=len(req.all_token_ids) - 1)
    assert result["requested_count"] > 0
    assert result["pinned_count"] == 0
    assert result["status"] == "capped"


# -----------------------------------------------------------------------------
# Additional tests merged from test_pinned_prefix_caching.py
# -----------------------------------------------------------------------------


def create_request(
    request_id: str, prompt_token_ids: list[int], pin_prefix: bool = False
) -> Request:
    """Helper function to create a request with optional prefix pinning."""
    sampling_params = SamplingParams(max_tokens=10, pin_prefix=pin_prefix)
    block_hasher = get_request_block_hasher(4, sha256_cbor)

    return Request(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=None,
        block_hasher=block_hasher,
    )


class TestPinnedPrefixCaching:
    """Unit-level test cases for pinned prefix caching."""

    def test_sampling_params_pin_prefix_default(self):
        """pin_prefix defaults to False in SamplingParams."""
        params = SamplingParams()
        assert params.pin_prefix is False

    def test_sampling_params_pin_prefix_enabled(self):
        """pin_prefix can be set to True in SamplingParams."""
        params = SamplingParams(pin_prefix=True)
        assert params.pin_prefix is True

    def test_sampling_params_from_optional_pin_prefix(self):
        """pin_prefix is passed through SamplingParams.from_optional."""
        params = SamplingParams.from_optional(pin_prefix=True)
        assert params.pin_prefix is True

    def test_block_pool_pin_blocks(self):
        """Blocks can be pinned to prevent eviction."""
        block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)

        # Get some blocks.
        blocks = block_pool.get_new_blocks(3)

        # Pin the blocks.
        block_pool.pin_blocks(blocks)

        # Verify the blocks are pinned.
        for block in blocks:
            assert block.is_pinned is True
            assert block.ref_cnt >= 1

    def test_block_pool_unpin_blocks(self):
        """Pinned blocks can be unpinned."""
        block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)

        # Get and pin some blocks.
        blocks = block_pool.get_new_blocks(3)
        block_pool.pin_blocks(blocks)

        # Unpin the blocks.
        block_pool.unpin_blocks(blocks)

        # Verify the blocks are unpinned.
        for block in blocks:
            assert block.is_pinned is False

    def test_pinned_blocks_protected_from_eviction(self):
        """Pinned blocks are protected from eviction; unpinned ones are not."""
        block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)

        # Get some blocks and make them look cached.
        blocks = block_pool.get_new_blocks(3)

        # Simulate caching by assigning block hashes through the BlockPool API.
        from vllm.v1.core.kv_cache_utils import make_block_hash_with_group_id

        for i, block in enumerate(blocks):
            # Compose a dummy BlockHashWithGroupId and set it via the property.
            dummy_hash = f"dummy_hash_{i}".encode()
            bh = make_block_hash_with_group_id(dummy_hash, 0)
            block.block_hash = bh
            # Insert into the cached-block map under the same key.
            block_pool.cached_block_hash_to_block.insert(bh, block)

        # Pin one of the blocks.
        block_pool.pin_blocks([blocks[0]])

        # Try to evict all blocks.
        for block in blocks:
            evicted = block_pool._maybe_evict_cached_block(block)
            if block == blocks[0]:
                # The pinned block must not be evicted.
                assert evicted is False
                assert block.block_hash is not None  # hash retained
            else:
                # Unpinned blocks should be evicted.
                assert evicted is True
                assert block.block_hash is None  # hash removed

    def test_cache_blocks_with_pin_prefix(self):
        """The pin_prefix setting is correctly stored in SamplingParams."""
        # Create a request with pin_prefix enabled.
        request = create_request("test_request", [1, 2, 3, 4, 5, 6], pin_prefix=True)

        # Verify that pin_prefix is set.
        assert request.sampling_params.pin_prefix is True

        # Calculate the number of blocks to pin: only full blocks count.
        block_size = 4
        num_computed_tokens = 6
        num_blocks_to_pin = num_computed_tokens // block_size  # 6 // 4 == 1

        assert num_blocks_to_pin == 1

    def test_cache_blocks_with_multiple_full_blocks_pinned(self):
        """Calculate multiple full blocks for pinning."""
        # Initialize the hash function (imported at module level).
        init_none_hash(sha256_cbor)

        # Create a request with pin_prefix enabled and enough tokens for
        # several full blocks.
        request = create_request("test_request", list(range(20)), pin_prefix=True)

        # Verify that pin_prefix is set.
        assert request.sampling_params.pin_prefix is True

        # Calculate blocks to pin with multiple full blocks.
        block_size = 4
        num_computed_tokens = 16  # 4 full blocks
        num_blocks_to_pin = num_computed_tokens // block_size  # 16 // 4 == 4

        assert num_blocks_to_pin == 4

    def test_cache_blocks_without_pin_prefix(self):
        """pin_prefix defaults to False when not requested."""
        # Initialize the hash function (imported at module level).
        init_none_hash(sha256_cbor)

        # Create a request without pin_prefix.
        request = create_request("test_request", list(range(20)), pin_prefix=False)

        # Verify that pin_prefix is False.
        assert request.sampling_params.pin_prefix is False