[V1][Spec Decode] Make eagle compatible with prefix caching. (#17137)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Lily Liu 2025-04-27 09:29:43 -07:00 committed by GitHub
parent 4213475ec7
commit 20e489eaa1
4 changed files with 81 additions and 9 deletions


@@ -719,3 +719,60 @@ def test_prefix_cache_stats_disabled():
    # Ensure prefix_cache_stats remains None
    assert manager.prefix_cache_stats is None


def test_eagle_enabled_removes_last_block():
    """Verify that Eagle drops the last matched block on a prefix cache hit."""
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Request with 3 full blocks (48 tokens)
    token_ids = [0] * (3 * block_size)
    req = make_request("divisible_request", token_ids)

    # Prime the cache
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids), computed_blocks)
    manager.free(req)

    # New request with same tokens + Eagle enabled
    req_eagle = make_request("eagle_divisible", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Should retain 1 block:
    # 1. Original 3 full blocks → last block hash popped because the prompt
    #    length is divisible by the block size → 2 matched blocks
    # 2. Eagle drops the last matched block → 1 block retained
    assert len(computed_blocks) == 1
    assert num_tokens == 1 * block_size  # 16 tokens


def test_eagle_with_partial_blocks():
    """Test Eagle behavior with requests containing partial blocks."""
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # 2 full blocks + 5 tokens (non-divisible length)
    token_ids = [0] * (2 * block_size + 5)
    req = make_request("partial_block_test", token_ids)

    # Prime the cache
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids), computed_blocks)
    manager.free(req)

    # New request with Eagle enabled
    req_eagle = make_request("partial_eagle", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
    assert len(computed_blocks) == 1
    assert num_tokens == 1 * block_size
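
Both expectations follow from two rules: the manager pops the last block hash when the prompt length is an exact multiple of the block size, and with Eagle it additionally drops one matched block so that block is recomputed to produce the hidden states the drafting head needs. A minimal standalone sketch of that arithmetic (illustrative only; expected_cached_blocks is a hypothetical helper, not part of vLLM):

def expected_cached_blocks(num_tokens: int, block_size: int,
                           use_eagle: bool) -> int:
    """Illustrative: how many cached blocks a repeated prompt should reuse."""
    matched = num_tokens // block_size
    # Rule 1: pop the final block hash when the prompt length is an exact
    # multiple of the block size, so at least one token is recomputed.
    if matched > 0 and matched * block_size == num_tokens:
        matched -= 1
    # Rule 2 (this change): with Eagle, drop one more matched block so its
    # hidden states can be recomputed for the drafting head.
    if use_eagle and matched > 0:
        matched -= 1
    return matched

# Matches the assertions in the two tests above (block_size = 16).
assert expected_cached_blocks(3 * 16, 16, use_eagle=True) == 1
assert expected_cached_blocks(2 * 16 + 5, 16, use_eagle=True) == 1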


@@ -44,7 +44,6 @@ def test_prompts():
@pytest.fixture
def sampling_config():
    # Only support greedy for now
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@@ -25,6 +25,7 @@ class KVCacheManager:
        max_model_len: int,
        enable_caching: bool = True,
        caching_hash_algo: str = "builtin",
        use_eagle: bool = False,
        log_stats: bool = False,
    ) -> None:
        assert len(kv_cache_config.kv_cache_groups) == 1, (
@@ -38,6 +39,7 @@ class KVCacheManager:
        self.enable_caching = enable_caching
        self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
        self.use_eagle = use_eagle
        self.log_stats = log_stats
        # FIXME: make prefix cache stats conditional on log_stats
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
@@ -134,6 +136,14 @@ class KVCacheManager:
        computed_blocks = (
            self.specialized_manager.find_longest_cache_hit(block_hashes))

        if self.use_eagle and len(computed_blocks) > 0:
            # Drop the last matched block if (1) eagle is enabled and
            # (2) there is a cache hit. The dropped block is recomputed so
            # that the hidden states required by the eagle drafting head
            # are available.
            computed_blocks.pop()

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.queries += len(block_hashes)
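
To make the effect of the new branch concrete, here is a condensed, standalone sketch of the lookup path (not vLLM's actual API: cached is a hypothetical dict from block hash to cached block, and the returned token count pairs with the matched blocks the same way the tests above assert, matched blocks × block size):

def find_computed_blocks_sketch(block_hashes: list[str],
                                cached: dict[str, str],
                                block_size: int,
                                use_eagle: bool) -> tuple[list[str], int]:
    computed_blocks = []
    for h in block_hashes:
        # Longest prefix of consecutive cache hits.
        if h not in cached:
            break
        computed_blocks.append(cached[h])
    if use_eagle and len(computed_blocks) > 0:
        # New in this change: recompute the last hit block so the eagle
        # drafting head gets hidden states for the final prefix token.
        computed_blocks.pop()
    return computed_blocks, len(computed_blocks) * block_size

# Two of three blocks are cached; with eagle only one is reused (16 tokens).
blocks, num_tokens = find_computed_blocks_sketch(
    ["h0", "h1", "h2"], {"h0": "blk0", "h1": "blk1"}, 16, use_eagle=True)
assert (blocks, num_tokens) == (["blk0"], 16)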


@@ -74,13 +74,6 @@ class Scheduler(SchedulerInterface):
        num_gpu_blocks = self.cache_config.num_gpu_blocks
        assert num_gpu_blocks is not None and num_gpu_blocks > 0

        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            enable_caching=self.cache_config.enable_prefix_caching,
            caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
            log_stats=self.log_stats)

        self.block_size = self.cache_config.block_size

        # req_id -> Request
@@ -123,12 +116,24 @@
            cache_size=encoder_cache_size)

        speculative_config = vllm_config.speculative_config

        self.use_eagle = False
        self.num_spec_tokens = self.num_lookahead_tokens = 0
        if speculative_config:
            self.num_spec_tokens = speculative_config.num_speculative_tokens
            if speculative_config.use_eagle():
                self.use_eagle = True
                self.num_lookahead_tokens = self.num_spec_tokens

        # Create the KV cache manager.
        self.kv_cache_manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            enable_caching=self.cache_config.enable_prefix_caching,
            caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
            use_eagle=self.use_eagle,
            log_stats=self.log_stats)

    def schedule(self) -> SchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -317,7 +322,8 @@
        # Get already-cached tokens.
        computed_blocks, num_computed_tokens = \
            self.kv_cache_manager.get_computed_blocks(request)
            self.kv_cache_manager.get_computed_blocks(
                request)

        # Get externally-cached tokens if using a KVConnector.
        num_external_tokens = (
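
With use_eagle now threaded from the speculative config into the KV cache manager, prefix caching can stay enabled when running EAGLE speculative decoding. A rough end-to-end sketch with the offline API (the model names, the speculative_config keys, and the shared-prefix prompts are illustrative assumptions, not part of this change):

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",      # assumed target model
    enable_prefix_caching=True,                       # prefix caching stays on
    speculative_config={
        "method": "eagle",                            # assumed config keys
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # assumed draft head
        "num_speculative_tokens": 3,
    },
)

# Mirrors the sampling_config fixture above: only greedy is supported for now.
params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)

# Requests sharing a long prefix exercise both features together: the prefix
# blocks are reused from the cache, while the last matched block is recomputed
# so the eagle drafting head has the hidden states it needs.
shared_prefix = "You are a helpful assistant. " * 50
outputs = llm.generate(
    [shared_prefix + "Summarize prefix caching.",
     shared_prefix + "Summarize speculative decoding."], params)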