[v1][Spec Decode] Make sliding window compatible with eagle prefix caching (#17398)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-05-01 02:25:53 +08:00 committed by GitHub
parent 42d9a2c4c7
commit 81ecf425f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 23 deletions

View File

@ -15,7 +15,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id,
@ -863,11 +863,11 @@ def test_eagle_enabled_removes_last_block():
req_eagle = make_request("eagle_divisible", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 2 blocks:
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
# 2. drop last matched block → 1 remaining block
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size # 32 tokens
assert num_tokens == 1 * block_size # 16 tokens
def test_eagle_with_partial_blocks():
@ -894,3 +894,59 @@ def test_eagle_with_partial_blocks():
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size
def test_eagle_with_sliding_window():
"""Test Eagle behavior with sliding window."""
block_size = 16
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(
num_blocks=10,
tensors={},
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
),
max_model_len=8192,
enable_caching=True,
use_eagle=True,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block) is not None
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
assert len(computed_blocks) == 0
assert num_tokens == 0

View File

@ -19,7 +19,9 @@ def test_sliding_window_possible_cached_prefix():
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
@ -79,7 +81,9 @@ def test_sliding_window_remove_skipped_blocks():
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)
null_block_id = block_pool.null_block.block_id

View File

@ -52,6 +52,7 @@ class KVCacheManager:
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
use_eagle=self.use_eagle,
)
# Mapping from request ID to blocks to track the blocks allocated
@ -141,13 +142,6 @@ class KVCacheManager:
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if self.use_eagle and len(computed_blocks) > 0:
# Drop the last matched block if (1) eagle is enabled and
# (2) there is a cache hit.
# This is to recompute the last block to get the required
# hidden states for eagle drafting head.
computed_blocks.pop()
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += len(block_hashes)

View File

@ -18,6 +18,7 @@ class SpecializedManager(ABC):
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
use_eagle: bool,
) -> None:
"""
Initializes the SpecializedManager.
@ -30,12 +31,17 @@ class SpecializedManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
found, return an empty list. if eagle is enabled, drop the last matched
block to force recompute the last block to get the required hidden
states for eagle drafting head.
Args:
block_hashes: The block hashes of the request.
@ -79,6 +85,8 @@ class FullAttentionManager(SpecializedManager):
computed_blocks.append(cached_block)
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@ -89,14 +97,20 @@ class FullAttentionManager(SpecializedManager):
class SlidingWindowManager(SpecializedManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec,
block_pool: BlockPool):
super().__init__(kv_cache_spec, block_pool)
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
use_eagle: bool):
super().__init__(kv_cache_spec, block_pool, use_eagle)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
if self.use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
self.sliding_window_contiguous_blocks += 1
self._null_block = block_pool.null_block
def find_longest_cache_hit(
@ -109,6 +123,7 @@ class SlidingWindowManager(SpecializedManager):
computed_blocks = [self._null_block] * len(block_hashes)
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
@ -121,12 +136,16 @@ class SlidingWindowManager(SpecializedManager):
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
return computed_blocks
match_found = True
break
else:
num_contiguous_blocks = 0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
@ -155,7 +174,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
block_pool: BlockPool) -> SpecializedManager:
**kwargs) -> SpecializedManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, block_pool)
manager = manager_class(kv_cache_spec, **kwargs)
return manager