mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 00:24:30 +08:00
[Core] Introduce popleft_n and append_n in FreeKVCacheBlockQueue to further optimize block_pool (#21222)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
parent
10904e6d75
commit
ed25054577
@ -184,6 +184,111 @@ def test_free_kv_cache_block_queue_operations():
|
||||
assert str(e.value) == "No free blocks available"
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_append_n():
    """Check that append_n links batches of 0, 1, 2 and 3 blocks at the tail.

    After every call the whole list is walked from the fake head to the fake
    tail and both link directions are verified at each hop.
    """
    # Start from a queue that holds no blocks at all.
    queue = FreeKVCacheBlockQueue([])
    blocks = [KVCacheBlock(block_id=i) for i in range(6)]

    def assert_free_list(expected):
        # The counter must match, and every neighbouring pair in
        # fake_head -> expected... -> fake_tail must point at each other.
        assert queue.num_free_blocks == len(expected)
        chain = [
            queue.fake_free_list_head,
            *expected,
            queue.fake_free_list_tail,
        ]
        for left, right in zip(chain, chain[1:]):
            assert left.next_free_block is right
            assert right.prev_free_block is left

    # Appending nothing leaves fake_head->fake_tail untouched.
    queue.append_n([])
    assert_free_list([])

    # One block: fake_head->b0->fake_tail
    queue.append_n(blocks[0:1])
    assert_free_list([blocks[0]])

    # Two blocks: fake_head->b0->b4->b5->fake_tail
    queue.append_n(blocks[4:6])
    assert_free_list([blocks[0], blocks[4], blocks[5]])

    # Three blocks: fake_head->b0->b4->b5->b1->b2->b3->fake_tail
    queue.append_n(blocks[1:4])
    assert_free_list([
        blocks[0], blocks[4], blocks[5], blocks[1], blocks[2], blocks[3]
    ])
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_popleft_n():
    """Check popleft_n for n = 0, 1, 2 and 3 on a six-block queue.

    Each step verifies the popped blocks (identity, order, and that their
    links were reset) and also the state the queue is left in: the
    num_free_blocks counter and the head-to-tail linkage of the remainder.
    """
    blocks = [KVCacheBlock(block_id=i) for i in range(6)]
    # Queue order deliberately differs from block_id order:
    # fake_head->b1->b3->b5->b4->b0->b2->fake_tail
    initial_order = [
        blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]
    ]
    queue = FreeKVCacheBlockQueue(list(initial_order))

    def assert_free_list(expected):
        # The counter must match, and every neighbouring pair in
        # fake_head -> expected... -> fake_tail must point at each other.
        assert queue.num_free_blocks == len(expected)
        chain = [
            queue.fake_free_list_head,
            *expected,
            queue.fake_free_list_tail,
        ]
        for left, right in zip(chain, chain[1:]):
            assert left.next_free_block is right
            assert right.prev_free_block is left

    def assert_popped(popped, expected):
        # Popped blocks come back in queue order, fully detached.
        assert len(popped) == len(expected)
        for got, want in zip(popped, expected):
            assert got is want
            assert got.prev_free_block is None
            assert got.next_free_block is None

    assert_free_list(initial_order)

    # Pop 0 block: nothing is returned and the queue is unchanged.
    # fake_head->b1->b3->b5->b4->b0->b2->fake_tail
    assert len(queue.popleft_n(0)) == 0
    assert_free_list(initial_order)

    # Pop 1 block
    # fake_head->b3->b5->b4->b0->b2->fake_tail
    assert_popped(queue.popleft_n(1), initial_order[:1])
    assert_free_list(initial_order[1:])

    # Pop 2 blocks
    # fake_head->b4->b0->b2->fake_tail
    assert_popped(queue.popleft_n(2), initial_order[1:3])
    assert_free_list(initial_order[3:])

    # Pop 3 blocks, draining the queue
    # fake_head->fake_tail
    assert_popped(queue.popleft_n(3), initial_order[3:])
    assert_free_list([])
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_get_all_free_blocks():
|
||||
# Create a list of KVCacheBlock objects
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(5)]
|
||||
|
||||
@ -214,21 +214,18 @@ class BlockPool:
|
||||
raise ValueError(
|
||||
f"Cannot get {num_blocks} free blocks from the pool")
|
||||
|
||||
ret: list[KVCacheBlock] = []
|
||||
idx = 0
|
||||
while idx < num_blocks:
|
||||
# First allocate blocks.
|
||||
curr_block = self.free_block_queue.popleft()
|
||||
assert curr_block.ref_cnt == 0
|
||||
|
||||
# If the block is cached, evict it.
|
||||
if self.enable_caching:
|
||||
self._maybe_evict_cached_block(curr_block)
|
||||
|
||||
curr_block.incr_ref()
|
||||
ret.append(curr_block)
|
||||
idx += 1
|
||||
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
|
||||
|
||||
# In order to only iterate the list once, we duplicated code a bit
|
||||
if self.enable_caching:
|
||||
for block in ret:
|
||||
self._maybe_evict_cached_block(block)
|
||||
assert block.ref_cnt == 0
|
||||
block.ref_cnt += 1
|
||||
else:
|
||||
for block in ret:
|
||||
assert block.ref_cnt == 0
|
||||
block.ref_cnt += 1
|
||||
return ret
|
||||
|
||||
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
|
||||
@ -289,11 +286,14 @@ class BlockPool:
|
||||
ordered_blocks: A list of blocks to free ordered by their eviction
|
||||
priority.
|
||||
"""
|
||||
for block in ordered_blocks:
|
||||
block.decr_ref()
|
||||
# null_block should not be added to the free list.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.append(block)
|
||||
# Materialize the iterable to allow multiple passes.
|
||||
blocks_list = list(ordered_blocks)
|
||||
for block in blocks_list:
|
||||
block.ref_cnt -= 1
|
||||
self.free_block_queue.append_n([
|
||||
block for block in blocks_list
|
||||
if block.ref_cnt == 0 and not block.is_null
|
||||
])
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
|
||||
@ -154,6 +154,8 @@ class KVCacheBlock:
|
||||
# Whether the block is a null block that should never be cached.
|
||||
is_null: bool = False
|
||||
|
||||
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
|
||||
# avoid function calls.
|
||||
def incr_ref(self):
    """Increase this block's reference count by one."""
    self.ref_cnt = self.ref_cnt + 1
|
||||
|
||||
@ -273,6 +275,39 @@ class FreeKVCacheBlockQueue:
|
||||
self.num_free_blocks -= 1
|
||||
return first_block
|
||||
|
||||
def popleft_n(self, n: int) -> list[KVCacheBlock]:
    """Pop the first n free blocks and reduce num_free_blocks by n.

    Args:
        n: The number of blocks to pop. A non-positive value pops
            nothing and leaves the queue untouched.

    Returns:
        A list of n free blocks in queue order (empty when n <= 0).
        The prev_free_block / next_free_block links of every returned
        block are reset to None.
    """
    # Guard non-positive n, not just zero: a negative n would skip the
    # loop below yet still *increase* num_free_blocks through the
    # subtraction, silently corrupting the free-block count.
    if n <= 0:
        return []
    assert self.num_free_blocks >= n
    self.num_free_blocks -= n

    curr_block = self.fake_free_list_head.next_free_block
    # Pop n blocks from the head of the list
    ret = []
    for _ in range(n):
        assert curr_block is not None
        ret.append(curr_block)
        last_block = curr_block
        # Advance before clearing the links of the block just popped.
        curr_block = curr_block.next_free_block
        # Reset prev_free_block and next_free_block of all popped blocks
        last_block.prev_free_block = None
        last_block.next_free_block = None

    if curr_block is not None:
        # curr_block is the first node that was not popped; when every
        # real block was popped this is the fake tail. Re-attach it
        # directly behind the fake head.
        self.fake_free_list_head.next_free_block = curr_block
        curr_block.prev_free_block = self.fake_free_list_head
    return ret
|
||||
|
||||
def remove(self, block: KVCacheBlock) -> None:
|
||||
"""Remove a block in the free list and reduce num_free_blocks by 1.
|
||||
|
||||
@ -315,6 +350,29 @@ class FreeKVCacheBlockQueue:
|
||||
|
||||
self.num_free_blocks += 1
|
||||
|
||||
def append_n(self, blocks: list[KVCacheBlock]) -> None:
    """Append several blocks to the end of the free list in one pass.

    The blocks are linked in the order given, directly in front of the
    fake tail, and num_free_blocks grows by the number of blocks.

    Args:
        blocks: The blocks to append. An empty list is a no-op.
    """
    added = len(blocks)
    if added == 0:
        return
    self.num_free_blocks += added

    tail = self.fake_free_list_tail
    prev = tail.prev_free_block
    assert prev is not None, (
        "prev_free_block of fake_free_list_tail should always exist")

    # Chain each new block behind the current end of the list.
    for node in blocks:
        prev.next_free_block = node
        node.prev_free_block = prev
        prev = node

    # Close the list: the final appended block now precedes the fake tail.
    prev.next_free_block = tail
    tail.prev_free_block = prev
|
||||
|
||||
def get_all_free_blocks(self) -> list[KVCacheBlock]:
|
||||
"""Get all free blocks in the free list. Mainly used for testing.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user