mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Bugfix] get_num_blocks_to_allocate with null_block (#19031)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
135cf55cd1
commit
b5fd9506c1
@ -144,3 +144,26 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
# of removed blocks should be [1003, 1002].
|
||||
manager.remove_skipped_blocks("test", 11)
|
||||
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
|
||||
|
||||
|
||||
def test_get_num_blocks_to_allocate():
    """Check the allocation count with and without null blocks.

    Two requests ask for the same number of tokens (20 blocks' worth).
    The first passes ten freshly constructed blocks as its computed
    blocks and expects 20 blocks to allocate. The second passes five
    null blocks followed by five freshly constructed blocks and expects
    15, i.e. the null-block entries do not add to the count.
    """
    block_size = 2
    spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=4,  # Placeholder value, not related to test result
        use_mla=False,
    )

    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
    manager = get_sliding_window_manager(spec, block_pool)

    # Computed blocks for request "1": ten regular blocks with ids 1..10.
    all_regular = [KVCacheBlock(block_id) for block_id in range(1, 11)]
    # Computed blocks for request "2": the pool's null block five times,
    # then five regular blocks with ids 1..5.
    null_prefixed = [block_pool.null_block] * 5
    null_prefixed += [KVCacheBlock(block_id) for block_id in range(1, 6)]

    num_tokens = 20 * block_size

    needed_all_regular = manager.get_num_blocks_to_allocate(
        "1", num_tokens, all_regular)
    assert needed_all_regular == 20

    needed_null_prefixed = manager.get_num_blocks_to_allocate(
        "2", num_tokens, null_prefixed)
    assert needed_null_prefixed == 15
|
||||
|
||||
@ -63,6 +63,7 @@ class BlockPool:
|
||||
# The ref_cnt of null_block is not maintained, needs special care to
|
||||
# avoid freeing it.
|
||||
self.null_block = self.free_block_queue.popleft()
|
||||
self.null_block.is_null = True
|
||||
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue: list[KVCacheEvent] = []
|
||||
@ -252,7 +253,7 @@ class BlockPool:
|
||||
for block in blocks:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and block != self.null_block:
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
|
||||
@ -267,7 +268,7 @@ class BlockPool:
|
||||
for block in ordered_blocks:
|
||||
block.decr_ref()
|
||||
# null_block should not be added to the free list.
|
||||
if block.ref_cnt == 0 and block != self.null_block:
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.append(block)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
|
||||
@ -125,6 +125,9 @@ class KVCacheBlock:
|
||||
prev_free_block: Optional["KVCacheBlock"] = None
|
||||
next_free_block: Optional["KVCacheBlock"] = None
|
||||
|
||||
# Whether the block is a null block that should never be cached.
|
||||
is_null: bool = False
|
||||
|
||||
def incr_ref(self):
|
||||
self.ref_cnt += 1
|
||||
|
||||
|
||||
@ -83,8 +83,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# free queue and ref_cnt == 0), it will be changed from a free block
|
||||
# to a computed block when the request is allocated, so we also count
|
||||
# it as needed to be allocated.
|
||||
num_evictable_computed_blocks = sum(blk.ref_cnt == 0
|
||||
for blk in new_computed_blocks)
|
||||
num_evictable_computed_blocks = sum(
|
||||
blk.ref_cnt == 0 and not blk.is_null
|
||||
for blk in new_computed_blocks)
|
||||
return ((num_new_blocks + num_evictable_computed_blocks) *
|
||||
self.num_kv_cache_groups)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user