mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 05:04:58 +08:00
[v1][Spec Decode] Make sliding window compatible with eagle prefix caching (#17398)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
42d9a2c4c7
commit
81ecf425f0
@ -15,7 +15,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
|||||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||||
hash_block_tokens)
|
hash_block_tokens)
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec)
|
KVCacheGroupSpec, SlidingWindowSpec)
|
||||||
|
|
||||||
|
|
||||||
def make_request(request_id,
|
def make_request(request_id,
|
||||||
@ -863,11 +863,11 @@ def test_eagle_enabled_removes_last_block():
|
|||||||
req_eagle = make_request("eagle_divisible", token_ids)
|
req_eagle = make_request("eagle_divisible", token_ids)
|
||||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||||
|
|
||||||
# Should retain 2 blocks:
|
# Should retain 1 block:
|
||||||
# 1. Original 3 blocks → pop last hash → 2 matched blocks
|
# 1. Original 3 blocks → pop last hash → 2 matched blocks
|
||||||
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
|
# 2. drop last matched block → 1 remaining block
|
||||||
assert len(computed_blocks) == 1
|
assert len(computed_blocks) == 1
|
||||||
assert num_tokens == 1 * block_size # 32 tokens
|
assert num_tokens == 1 * block_size # 16 tokens
|
||||||
|
|
||||||
|
|
||||||
def test_eagle_with_partial_blocks():
|
def test_eagle_with_partial_blocks():
|
||||||
@ -894,3 +894,59 @@ def test_eagle_with_partial_blocks():
|
|||||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||||
assert len(computed_blocks) == 1
|
assert len(computed_blocks) == 1
|
||||||
assert num_tokens == 1 * block_size
|
assert num_tokens == 1 * block_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_eagle_with_sliding_window():
|
||||||
|
"""Test Eagle behavior with sliding window."""
|
||||||
|
block_size = 16
|
||||||
|
sliding_window_spec = SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
use_mla=False,
|
||||||
|
)
|
||||||
|
manager = KVCacheManager(
|
||||||
|
KVCacheConfig(
|
||||||
|
num_blocks=10,
|
||||||
|
tensors={},
|
||||||
|
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
|
||||||
|
),
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
use_eagle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
|
token_ids = [0] * (2 * block_size + 5)
|
||||||
|
req = make_request("partial_block_test", token_ids)
|
||||||
|
|
||||||
|
# Prime the cache
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
manager.allocate_slots(req, len(token_ids), computed_blocks)
|
||||||
|
# record the block hash of the first block in the request for later use
|
||||||
|
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
|
||||||
|
assert block_hash_first_block is not None
|
||||||
|
manager.free(req)
|
||||||
|
|
||||||
|
# New request with Eagle enabled
|
||||||
|
req_eagle = make_request("partial_eagle", token_ids)
|
||||||
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||||
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||||
|
assert len(computed_blocks) == 1
|
||||||
|
assert num_tokens == 1 * block_size
|
||||||
|
|
||||||
|
# Evict the first block in the request
|
||||||
|
assert manager.block_pool.get_cached_block(
|
||||||
|
block_hash_first_block) is not None
|
||||||
|
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
|
||||||
|
|
||||||
|
# New request
|
||||||
|
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
|
||||||
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
||||||
|
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||||
|
# not considered. But after dropping the last matched block due to eagle,
|
||||||
|
# there will be no matched prefix.
|
||||||
|
assert len(computed_blocks) == 0
|
||||||
|
assert num_tokens == 0
|
||||||
|
|||||||
@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||||
manager = SlidingWindowManager(sliding_window_spec, block_pool)
|
manager = SlidingWindowManager(sliding_window_spec,
|
||||||
|
block_pool,
|
||||||
|
use_eagle=False)
|
||||||
|
|
||||||
def run_one_case(block_is_cached, expect_length):
|
def run_one_case(block_is_cached, expect_length):
|
||||||
block_hash_list = [
|
block_hash_list = [
|
||||||
@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks():
|
|||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||||
|
|
||||||
manager = SlidingWindowManager(sliding_window_spec, block_pool)
|
manager = SlidingWindowManager(sliding_window_spec,
|
||||||
|
block_pool,
|
||||||
|
use_eagle=False)
|
||||||
|
|
||||||
null_block_id = block_pool.null_block.block_id
|
null_block_id = block_pool.null_block.block_id
|
||||||
|
|
||||||
|
|||||||
@ -52,6 +52,7 @@ class KVCacheManager:
|
|||||||
self.specialized_manager = get_specialized_manager(
|
self.specialized_manager = get_specialized_manager(
|
||||||
kv_cache_spec=kv_cache_spec,
|
kv_cache_spec=kv_cache_spec,
|
||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
|
use_eagle=self.use_eagle,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mapping from request ID to blocks to track the blocks allocated
|
# Mapping from request ID to blocks to track the blocks allocated
|
||||||
@ -141,13 +142,6 @@ class KVCacheManager:
|
|||||||
computed_blocks = (
|
computed_blocks = (
|
||||||
self.specialized_manager.find_longest_cache_hit(block_hashes))
|
self.specialized_manager.find_longest_cache_hit(block_hashes))
|
||||||
|
|
||||||
if self.use_eagle and len(computed_blocks) > 0:
|
|
||||||
# Drop the last matched block if (1) eagle is enabled and
|
|
||||||
# (2) there is a cache hit.
|
|
||||||
# This is to recompute the last block to get the required
|
|
||||||
# hidden states for eagle drafting head.
|
|
||||||
computed_blocks.pop()
|
|
||||||
|
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
assert self.prefix_cache_stats is not None
|
assert self.prefix_cache_stats is not None
|
||||||
self.prefix_cache_stats.queries += len(block_hashes)
|
self.prefix_cache_stats.queries += len(block_hashes)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class SpecializedManager(ABC):
|
|||||||
self,
|
self,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
|
use_eagle: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the SpecializedManager.
|
Initializes the SpecializedManager.
|
||||||
@ -30,12 +31,17 @@ class SpecializedManager(ABC):
|
|||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
self.block_pool = block_pool
|
self.block_pool = block_pool
|
||||||
|
|
||||||
|
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||||
|
self.use_eagle = use_eagle
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||||
"""
|
"""
|
||||||
Get the longest cache hit prefix of the blocks. If no cache hit is
|
Get the longest cache hit prefix of the blocks. If no cache hit is
|
||||||
found, return an empty list.
|
found, return an empty list. if eagle is enabled, drop the last matched
|
||||||
|
block to force recompute the last block to get the required hidden
|
||||||
|
states for eagle drafting head.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
block_hashes: The block hashes of the request.
|
block_hashes: The block hashes of the request.
|
||||||
@ -79,6 +85,8 @@ class FullAttentionManager(SpecializedManager):
|
|||||||
computed_blocks.append(cached_block)
|
computed_blocks.append(cached_block)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
if self.use_eagle and len(computed_blocks) > 0:
|
||||||
|
computed_blocks.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||||
@ -89,14 +97,20 @@ class FullAttentionManager(SpecializedManager):
|
|||||||
|
|
||||||
class SlidingWindowManager(SpecializedManager):
|
class SlidingWindowManager(SpecializedManager):
|
||||||
|
|
||||||
def __init__(self, kv_cache_spec: SlidingWindowSpec,
|
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
|
||||||
block_pool: BlockPool):
|
use_eagle: bool):
|
||||||
super().__init__(kv_cache_spec, block_pool)
|
super().__init__(kv_cache_spec, block_pool, use_eagle)
|
||||||
self.sliding_window = kv_cache_spec.sliding_window
|
self.sliding_window = kv_cache_spec.sliding_window
|
||||||
# The number of contiguous blocks needed for prefix cache hit.
|
# The number of contiguous blocks needed for prefix cache hit.
|
||||||
# -1 since the input token itself is also included in the window
|
# -1 since the input token itself is also included in the window
|
||||||
self.sliding_window_contiguous_blocks = cdiv(
|
self.sliding_window_contiguous_blocks = cdiv(
|
||||||
(kv_cache_spec.sliding_window - 1), self.block_size)
|
(kv_cache_spec.sliding_window - 1), self.block_size)
|
||||||
|
if self.use_eagle:
|
||||||
|
# Need to drop the last matched block if eagle is enabled. For
|
||||||
|
# sliding window layer, we achieve this by increasing the number of
|
||||||
|
# contiguous blocks needed for prefix cache hit by one and dropping
|
||||||
|
# the last matched block.
|
||||||
|
self.sliding_window_contiguous_blocks += 1
|
||||||
self._null_block = block_pool.null_block
|
self._null_block = block_pool.null_block
|
||||||
|
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
@ -109,6 +123,7 @@ class SlidingWindowManager(SpecializedManager):
|
|||||||
computed_blocks = [self._null_block] * len(block_hashes)
|
computed_blocks = [self._null_block] * len(block_hashes)
|
||||||
num_contiguous_blocks = 0
|
num_contiguous_blocks = 0
|
||||||
|
|
||||||
|
match_found = False
|
||||||
# Search from right to left and early stop when a match is found.
|
# Search from right to left and early stop when a match is found.
|
||||||
for i in range(len(block_hashes) - 1, -1, -1):
|
for i in range(len(block_hashes) - 1, -1, -1):
|
||||||
if cached_block := self.block_pool.get_cached_block(
|
if cached_block := self.block_pool.get_cached_block(
|
||||||
@ -121,12 +136,16 @@ class SlidingWindowManager(SpecializedManager):
|
|||||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||||
# when sliding_window_contiguous_blocks=2.
|
# when sliding_window_contiguous_blocks=2.
|
||||||
del computed_blocks[i + num_contiguous_blocks:]
|
del computed_blocks[i + num_contiguous_blocks:]
|
||||||
return computed_blocks
|
match_found = True
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
num_contiguous_blocks = 0
|
num_contiguous_blocks = 0
|
||||||
# The first `num_contiguous_blocks` is a cache hit even if
|
if not match_found:
|
||||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
# The first `num_contiguous_blocks` is a cache hit even if
|
||||||
del computed_blocks[num_contiguous_blocks:]
|
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||||
|
del computed_blocks[num_contiguous_blocks:]
|
||||||
|
if self.use_eagle and len(computed_blocks) > 0:
|
||||||
|
computed_blocks.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||||
@ -155,7 +174,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
|
|||||||
|
|
||||||
|
|
||||||
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
|
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
|
||||||
block_pool: BlockPool) -> SpecializedManager:
|
**kwargs) -> SpecializedManager:
|
||||||
manager_class = spec_manager_map[type(kv_cache_spec)]
|
manager_class = spec_manager_map[type(kv_cache_spec)]
|
||||||
manager = manager_class(kv_cache_spec, block_pool)
|
manager = manager_class(kv_cache_spec, **kwargs)
|
||||||
return manager
|
return manager
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user