mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 17:16:20 +08:00
[V1] Add kv cache utils tests. (#11513)
Signed-off-by: xcnick <xcnick0412@gmail.com>
This commit is contained in:
parent
fbf2564554
commit
d91457d529
241
tests/v1/core/test_kv_cache_utils.py
Normal file
241
tests/v1/core/test_kv_cache_utils.py
Normal file
@ -0,0 +1,241 @@
|
||||
import pytest
|
||||
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
mm_positions=None,
|
||||
mm_hashes=None):
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
inputs=token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_placeholders={"image": mm_positions}
|
||||
if mm_positions else None,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
),
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
eos_token_id=100,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
)
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
# Test KVCacheBlock initialization
|
||||
block = KVCacheBlock(block_id=0)
|
||||
assert block.block_id == 0
|
||||
assert block.ref_cnt == 0
|
||||
assert block.block_hash is None
|
||||
|
||||
# Test reference count manipulation
|
||||
block.incr_ref()
|
||||
assert block.ref_cnt == 1
|
||||
block.decr_ref()
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
# Test block hash setting and resetting
|
||||
block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3))
|
||||
block.block_hash = block_hash
|
||||
assert block.block_hash == block_hash
|
||||
|
||||
block.reset_hash()
|
||||
assert block.block_hash is None
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_initialization():
|
||||
# Test with a single block
|
||||
block = KVCacheBlock(block_id=0)
|
||||
queue = FreeKVCacheBlockQueue([block])
|
||||
assert queue.num_free_blocks == 1
|
||||
assert queue.free_list_head == block
|
||||
assert queue.free_list_tail == block
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_operations():
|
||||
# Create a list of KVCacheBlock objects
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(5)]
|
||||
|
||||
# Create a FreeKVCacheBlockQueue with these blocks
|
||||
queue = FreeKVCacheBlockQueue(blocks)
|
||||
|
||||
# Check initial state
|
||||
assert queue.num_free_blocks == 5
|
||||
assert queue.free_list_head == blocks[0]
|
||||
assert queue.free_list_tail == blocks[4]
|
||||
|
||||
# Pop the first block
|
||||
block1 = queue.popleft()
|
||||
assert block1 == blocks[0]
|
||||
assert queue.num_free_blocks == 4
|
||||
assert queue.free_list_head == blocks[1]
|
||||
assert queue.free_list_tail == blocks[4]
|
||||
|
||||
# Remove a block from the middle
|
||||
block_to_remove = blocks[2]
|
||||
queue.remove(block_to_remove)
|
||||
assert queue.num_free_blocks == 3
|
||||
assert blocks[1].next_free_block == blocks[3]
|
||||
assert blocks[3].prev_free_block == blocks[1]
|
||||
|
||||
# Append a block back
|
||||
queue.append(block_to_remove)
|
||||
assert queue.num_free_blocks == 4
|
||||
assert queue.free_list_tail == block_to_remove
|
||||
assert block_to_remove.prev_free_block == blocks[4]
|
||||
assert block_to_remove.next_free_block is None
|
||||
|
||||
# Pop blocks until empty
|
||||
for _ in range(4):
|
||||
queue.popleft()
|
||||
assert queue.num_free_blocks == 0
|
||||
assert queue.free_list_head is None
|
||||
assert queue.free_list_tail is None
|
||||
|
||||
# Attempt to pop from an empty queue
|
||||
with pytest.raises(ValueError) as e:
|
||||
queue.popleft()
|
||||
assert str(e.value) == "No free blocks available"
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_get_all_free_blocks():
|
||||
# Create a list of KVCacheBlock objects
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(5)]
|
||||
|
||||
# Create a FreeKVCacheBlockQueue with these blocks
|
||||
queue = FreeKVCacheBlockQueue(blocks)
|
||||
|
||||
# Check all blocks are correctly retrieved
|
||||
assert queue.get_all_free_blocks() == blocks
|
||||
|
||||
# Pop a block and check again
|
||||
queue.popleft()
|
||||
assert queue.get_all_free_blocks() == blocks[1:]
|
||||
|
||||
# Remove a block and check again
|
||||
block_to_remove = blocks[2]
|
||||
queue.remove(block_to_remove)
|
||||
assert queue.get_all_free_blocks() == blocks[1:2] + blocks[3:]
|
||||
|
||||
# Append a block back and check again
|
||||
queue.append(block_to_remove)
|
||||
assert queue.get_all_free_blocks() == \
|
||||
blocks[1:2] + blocks[3:] + [block_to_remove]
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(20)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 5
|
||||
}, {
|
||||
"offset": 10,
|
||||
"length": 5
|
||||
}],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
|
||||
# Test with no extra keys
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||
assert extra_keys == (("hash1", 0), )
|
||||
assert next_mm_idx == 1
|
||||
|
||||
# Test with partial overlap
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0)
|
||||
assert extra_keys == (("hash1", 3), )
|
||||
assert next_mm_idx == 1
|
||||
|
||||
# Test with no overlap
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
|
||||
assert extra_keys == ()
|
||||
assert next_mm_idx == 1
|
||||
|
||||
# Test with multiple extra keys
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0)
|
||||
assert extra_keys == (("hash1", 0), ("hash2", 0))
|
||||
assert next_mm_idx == 2
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_no_mm_inputs():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
)
|
||||
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||
assert extra_keys is None
|
||||
assert next_mm_idx == 0
|
||||
|
||||
|
||||
def test_hash_block_tokens():
|
||||
parent_block_hash = 123
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
|
||||
block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
|
||||
extra_keys)
|
||||
assert isinstance(block_hash, BlockHashType)
|
||||
assert block_hash.hash_value == hash(
|
||||
(parent_block_hash, *curr_block_token_ids))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
|
||||
|
||||
def test_hash_request_tokens():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], BlockHashType)
|
||||
assert isinstance(block_hashes[1], BlockHashType)
|
||||
|
||||
# Check the first block
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys == (("hash1", 0), )
|
||||
|
||||
# Check the second block
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys == (("hash2", 0), )
|
||||
|
||||
|
||||
def test_hash_request_tokens_no_mm_inputs():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys is None
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys is None
|
||||
Loading…
x
Reference in New Issue
Block a user