[V1] Remove pre-allocation for KV cache (#16941)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-04-22 00:52:18 -07:00 committed by GitHub
parent 2689d5c027
commit c4ab9f3e71
5 changed files with 61 additions and 141 deletions


@@ -496,8 +496,7 @@ def test_allocate_with_lookahead():
# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=0)
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
@@ -507,25 +506,19 @@ def test_allocate_with_lookahead():
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
max_model_len=100)
# required_blocks = ceil((3 + 2) / 4) = 2
# total_blocks = 1 + 2 = 3
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 3
assert len(blocks) == 2
# Test case 3: With precomputed blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
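
For reference, a minimal standalone sketch (plain Python, no vLLM imports; blocks_needed is a hypothetical helper, not a vLLM API) of the block arithmetic the updated lookahead assertions rely on: with pre-allocation removed, allocate_slots only covers the computed, new, and lookahead tokens.

from math import ceil

def blocks_needed(num_computed: int, num_new: int, num_lookahead: int,
                  block_size: int) -> int:
    # Total slots the request will occupy, rounded up to whole blocks.
    return ceil((num_computed + num_new + num_lookahead) / block_size)

# Test case 2 above: block_size=4, 3 new tokens, 2 lookahead tokens.
assert blocks_needed(0, 3, 2, block_size=4) == 2  # was 3 with 1 preallocated block
# Test case 3 above: 3 new tokens, 4 lookahead tokens.
assert blocks_needed(0, 3, 4, block_size=4) == 2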


@@ -8,7 +8,7 @@ import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv, sha256
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
@@ -61,7 +61,6 @@ def test_prefill(hash_algo):
max_model_len=8192,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
)
# choose the hash function according to the parameter
@@ -80,7 +79,7 @@ def test_prefill(hash_algo):
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Check full block metadata
parent_block_hash = None
@@ -92,8 +91,8 @@ def test_prefill(hash_algo):
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (4, 5):
# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -107,12 +106,12 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7]
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3
# At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0)
manager.free(req1)
@@ -120,14 +119,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req1 (5)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
@@ -139,29 +138,29 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [8, 9]
assert [b.block_id for b in blocks] == [6]
# Although we only have 5 free blocks, we have 8 blocks in
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 6
assert all([
b.ref_cnt == 0
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
])
assert len([
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) == 5
]) == 6
manager.free(req2)
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
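
As a quick sanity check on the new expected block IDs, a standalone sketch (plain Python, illustrative only) of the arithmetic for req0's 55-token prompt with block_size 16:

from math import ceil

block_size, prompt_tokens = 16, 55               # req0 in test_prefill above
full_blocks = prompt_tokens // block_size        # 3 fully filled, hashed blocks
total_blocks = ceil(prompt_tokens / block_size)  # 4 blocks now allocated in total
assert (full_blocks, total_blocks) == (3, 4)
# Previously one extra block was preallocated (cdiv(16, 16) == 1), hence the
# old expectation [1, 2, 3, 4, 5] versus the new [1, 2, 3, 4].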
@@ -178,7 +177,6 @@ def test_prefill_plp():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# the default hash function is hash
hash_fn = hash
@@ -197,7 +195,7 @@ def test_prefill_plp():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata
@@ -210,8 +208,8 @@ def test_prefill_plp():
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (4, 5):
# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -226,12 +224,12 @@ def test_prefill_plp():
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7]
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3
# At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0)
manager.free(req1)
@@ -239,14 +237,14 @@ def test_prefill_plp():
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req1 (5)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
@@ -262,7 +260,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4, 5]
assert block_ids != [1, 2, 3, 4]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@@ -277,7 +275,6 @@ def test_decode():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# Complete 3 blocks (48 tokens)
@@ -291,7 +288,7 @@ def test_decode():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@@ -299,28 +296,18 @@ def test_decode():
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
# Append slots without allocating a new block, but start using the
# preallocated block.
req0.num_computed_tokens = 59
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(5 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
req0.num_computed_tokens = 74
# 6 tokens to fill the previous block, and 10 tokens to fill
req0.num_computed_tokens = 59
# 9 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
def test_evict():
@@ -328,7 +315,6 @@ def test_evict():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
last_token_id = 5 * 16 + 7
@@ -337,7 +323,7 @@ def test_evict():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
assert len(blocks) == 6 # 5 full + 1 partial
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
@@ -349,7 +335,8 @@ def test_evict():
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
assert manager.block_pool.free_block_queue.num_free_blocks == 0
# 10 - (6 + 3) == 1
assert manager.block_pool.free_block_queue.num_free_blocks == 1
manager.free(req0)
manager.free(req1)
@@ -357,7 +344,7 @@ def test_evict():
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
@@ -365,8 +352,8 @@ def test_evict():
assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [7, 6]
assert manager.block_pool.free_block_queue.num_free_blocks == 6
assert [b.block_id for b in blocks] == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
def test_hash_block_correct_reuse():
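
The free-block accounting behind the updated test_evict assertions, as a standalone sketch (plain Python; block IDs start at 1 in these tests, so 10 of the 11 blocks are usable):

from math import ceil

block_size, usable_blocks = 16, 10
req0_blocks = ceil((5 * 16 + 7) / block_size)  # 6: five full + one partial
req1_blocks = ceil((3 * 16) / block_size)      # 3 full blocks
# Matches the new assertion: 10 - (6 + 3) == 1 free block left.
assert usable_blocks - (req0_blocks + req1_blocks) == 1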
@@ -379,7 +366,6 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate 1 block and cache it.
@@ -416,7 +402,6 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate a block and cache it.
@@ -465,7 +450,6 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
num_preallocate_tokens=0,
)
req1 = make_request("1", list(range(10))) # 2 blocks and some more
@@ -496,40 +480,6 @@ def test_basic_prefix_caching_disabled():
assert not blocks
@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
"""
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
req = make_request("0", list(range(block_size * 30)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert num_computed_tokens == 0
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks
# Assume all computed, only when num_preallocate_tokens > 0, we need to
# consume the previously preallocated blocks.
if num_preallocated_blocks > 0:
manager.allocate_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)
# Append 1 block.
blocks = manager.allocate_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_cache_blocks(hash_fn):
"""
@@ -588,7 +538,6 @@ def test_mm_prefix_caching():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -626,7 +575,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@@ -667,7 +616,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -721,7 +669,6 @@ def test_reset_prefix_cache():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]


@@ -804,20 +804,17 @@ def _assert_right_kv_cache_manager(
"""Check whether KVCacheManager is correct after allocate."""
# Make sure the request stats are right.
EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size
EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS +
scheduler.kv_cache_manager.num_preallocate_blocks)
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
EXPECTED_ACTUAL_BLOCKS)
EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_ACTUAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
# Make sure we actually touched all the blocks.
BLOCKS_PER_REQ = (num_tokens / block_size +
scheduler.kv_cache_manager.num_preallocate_blocks)
BLOCKS_PER_REQ = num_tokens / block_size
assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
num_total_blocks - num_requests * BLOCKS_PER_REQ)
@@ -1052,7 +1049,6 @@ def test_kv_connector_handles_preemption():
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
)
scheduler.kv_cache_manager.num_preallocate_blocks = 0
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")


@@ -25,7 +25,6 @@ class KVCacheManager:
max_model_len: int,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
assert len(kv_cache_config.kv_cache_groups) == 1, (
@@ -42,22 +41,8 @@ class KVCacheManager:
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we
# reduce the overhead of updating free_block_ids and ref_cnts for each
# request every step (at the cost of some memory waste).
# NOTE(woosuk): This is different from the "lookahead" slots since this
# does not guarantee that the request always has N empty blocks. After
# the request gets N empty blocks, it starts to use the blocks without
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
self.block_size)
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
@@ -256,13 +241,9 @@ class KVCacheManager:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
# Get new blocks from the free block pool.
num_new_blocks = min(
num_new_blocks + num_preallocate_blocks,
num_new_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
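
A minimal sketch of the allocation sizing after this change, using stand-in names rather than vLLM's actual classes (the per-request block-table cap referenced in the truncated comment above is omitted): the manager now requests exactly the blocks needed for the computed, new, and lookahead tokens, capped by the free pool, with no preallocation padding.

from math import ceil

def num_blocks_to_allocate(num_computed_tokens: int, num_new_tokens: int,
                           num_lookahead_tokens: int, block_size: int,
                           num_allocated_blocks: int,
                           num_free_blocks: int) -> int:
    # Blocks needed to cover every slot the request will touch this step.
    num_required_blocks = ceil(
        (num_computed_tokens + num_new_tokens + num_lookahead_tokens)
        / block_size)
    num_new_blocks = num_required_blocks - num_allocated_blocks
    if num_new_blocks <= 0:
        return 0  # the already-allocated blocks still have room
    # No preallocated extras any more: only cap by what the pool can provide.
    return min(num_new_blocks, num_free_blocks)

# Test case 2 from the lookahead test above: block_size=4, 3 new, 2 lookahead.
assert num_blocks_to_allocate(0, 3, 2, 4, 0, 10) == 2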


@@ -358,8 +358,11 @@ class Scheduler(SchedulerInterface):
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens + num_external_tokens,
computed_blocks)
request,
num_new_tokens + num_external_tokens,
computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
break
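
Finally, a toy, self-contained illustration (stand-in class, not vLLM's scheduler or manager) of why the scheduler now forwards num_lookahead_tokens: without the hint, the manager would size the allocation for the new tokens only and could come up short for the lookahead slots written in the same step.

from math import ceil

class ToyKVCacheManager:
    """Stand-in for the real manager: returns how many new blocks it grants."""

    def __init__(self, block_size: int, num_free_blocks: int) -> None:
        self.block_size = block_size
        self.num_free_blocks = num_free_blocks

    def allocate_slots(self, num_new_tokens: int,
                       num_lookahead_tokens: int = 0) -> int:
        needed = ceil((num_new_tokens + num_lookahead_tokens) / self.block_size)
        return min(needed, self.num_free_blocks)

manager = ToyKVCacheManager(block_size=16, num_free_blocks=10)
assert manager.allocate_slots(48) == 3                            # new tokens only
assert manager.allocate_slots(48, num_lookahead_tokens=16) == 4   # room for lookahead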