[Core] Introduce popleft_n and append_n in FreeKVCacheBlockQueue to further optimize block_pool (#21222)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang 2025-07-22 06:17:47 -07:00 committed by GitHub
parent 10904e6d75
commit ed25054577
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 182 additions and 19 deletions

View File

@ -184,6 +184,111 @@ def test_free_kv_cache_block_queue_operations():
assert str(e.value) == "No free blocks available"
def test_free_kv_cache_block_queue_append_n():
# Create an empty FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue([])
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Append 0 block
# fake_head->fake_tail
queue.append_n([])
assert queue.num_free_blocks == 0
assert (queue.fake_free_list_head.next_free_block
is queue.fake_free_list_tail)
assert (queue.fake_free_list_tail.prev_free_block
is queue.fake_free_list_head)
# Append 1 block
# fake_head->b0->fake_tail
queue.append_n(blocks[0:1])
assert queue.num_free_blocks == 1
assert queue.fake_free_list_head.next_free_block is blocks[0]
assert blocks[0].prev_free_block is queue.fake_free_list_head
assert blocks[0].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[0]
# Append 2 blocks
# fake_head->b0->b4->b5->fake_tail
queue.append_n(blocks[4:6])
assert queue.num_free_blocks == 3
assert queue.fake_free_list_head.next_free_block is blocks[0]
assert blocks[0].prev_free_block is queue.fake_free_list_head
assert blocks[0].next_free_block is blocks[4]
assert blocks[4].prev_free_block is blocks[0]
assert blocks[4].next_free_block is blocks[5]
assert blocks[5].prev_free_block is blocks[4]
assert blocks[5].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[5]
# Append 3 blocks
# fake_head->b0->b4->b5->b1->b2->b3->fake_tail
queue.append_n(blocks[1:4])
assert queue.num_free_blocks == 6
assert queue.fake_free_list_head.next_free_block is blocks[0]
assert blocks[0].prev_free_block is queue.fake_free_list_head
assert blocks[0].next_free_block is blocks[4]
assert blocks[4].prev_free_block is blocks[0]
assert blocks[4].next_free_block is blocks[5]
assert blocks[5].prev_free_block is blocks[4]
assert blocks[5].next_free_block is blocks[1]
assert blocks[1].prev_free_block is blocks[5]
assert blocks[1].next_free_block is blocks[2]
assert blocks[2].prev_free_block is blocks[1]
assert blocks[2].next_free_block is blocks[3]
assert blocks[3].prev_free_block is blocks[2]
assert blocks[3].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[3]
def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create a empty FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue(
[blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]])
assert queue.num_free_blocks == 6
assert queue.fake_free_list_head.next_free_block is blocks[1]
assert blocks[1].prev_free_block is queue.fake_free_list_head
assert blocks[1].next_free_block is blocks[3]
assert blocks[3].prev_free_block is blocks[1]
assert blocks[3].next_free_block is blocks[5]
assert blocks[5].prev_free_block is blocks[3]
assert blocks[5].next_free_block is blocks[4]
assert blocks[4].prev_free_block is blocks[5]
assert blocks[4].next_free_block is blocks[0]
assert blocks[0].prev_free_block is blocks[4]
assert blocks[0].next_free_block is blocks[2]
assert blocks[2].prev_free_block is blocks[0]
assert blocks[2].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[2]
# Pop 0 block
# fake_head->b1->b3->b5->b4->b0->b2->fake_tail
assert len(queue.popleft_n(0)) == 0
# Pop 1 block
# fake_head->b3->b5->b4->b0->b2->fake_tail
result_blocks = queue.popleft_n(1)
assert len(result_blocks) == 1
assert result_blocks[0] is blocks[1]
for block in result_blocks:
assert block.prev_free_block is None
assert block.next_free_block is None
# Pop 2 blocks
# fake_head->b4->b0->b2->fake_tail
result_blocks = queue.popleft_n(2)
assert len(result_blocks) == 2
assert result_blocks[0] is blocks[3]
assert result_blocks[1] is blocks[5]
for block in result_blocks:
assert block.prev_free_block is None
assert block.next_free_block is None
# Pop 3 blocks
# fake_head->fake_tail
result_blocks = queue.popleft_n(3)
assert len(result_blocks) == 3
assert result_blocks[0] is blocks[4]
assert result_blocks[1] is blocks[0]
assert result_blocks[2] is blocks[2]
for block in result_blocks:
assert block.prev_free_block is None
assert block.next_free_block is None
def test_free_kv_cache_block_queue_get_all_free_blocks():
# Create a list of KVCacheBlock objects
blocks = [KVCacheBlock(block_id=i) for i in range(5)]

View File

@ -214,21 +214,18 @@ class BlockPool:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: list[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
@ -289,11 +286,14 @@ class BlockPool:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append(block)
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n([
block for block in blocks_list
if block.ref_cnt == 0 and not block.is_null
])
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF

View File

@ -154,6 +154,8 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
self.ref_cnt += 1
@ -273,6 +275,39 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks -= 1
return first_block
def popleft_n(self, n: int) -> list[KVCacheBlock]:
"""Pop the first n free blocks and reduce num_free_blocks by n.
Args:
n: The number of blocks to pop.
Returns:
A list of n free blocks.
"""
if n == 0:
return []
assert self.num_free_blocks >= n
self.num_free_blocks -= n
curr_block = self.fake_free_list_head.next_free_block
# Pop n blocks from the head of the list
ret = []
for _ in range(n):
assert curr_block is not None
ret.append(curr_block)
last_block = curr_block
curr_block = curr_block.next_free_block
# Reset prev_free_block and next_free_block of all popped blocks
last_block.prev_free_block = None
last_block.next_free_block = None
if curr_block is not None:
# The queue is not empty, connect the fake head to
# the new first block.
self.fake_free_list_head.next_free_block = curr_block
curr_block.prev_free_block = self.fake_free_list_head
return ret
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
@ -315,6 +350,29 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks += 1
def append_n(self, blocks: list[KVCacheBlock]) -> None:
"""Put a list of blocks back into the free list
Args:
blocks: The blocks to append.
"""
if len(blocks) == 0:
return
self.num_free_blocks += len(blocks)
last_block = self.fake_free_list_tail.prev_free_block
assert last_block is not None, (
"prev_free_block of fake_free_list_tail should always exist")
# Add inter-connections between consecutive blocks
for block in blocks:
block.prev_free_block = last_block
last_block.next_free_block = block
last_block = block
# Connect the last block of <blocks> to the fake tail
last_block.next_free_block = self.fake_free_list_tail
self.fake_free_list_tail.prev_free_block = last_block
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.