mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:34:57 +08:00
[V1][Spec Decode] Make eagle compatible with prefix caching. (#17137)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
parent
4213475ec7
commit
20e489eaa1
@ -719,3 +719,60 @@ def test_prefix_cache_stats_disabled():
|
||||
|
||||
# Ensure prefix_cache_stats remains None
|
||||
assert manager.prefix_cache_stats is None
|
||||
|
||||
|
||||
def test_eagle_enabled_removes_last_block():
    """Verify Eagle drops the last matched block on a full cache hit.

    The request length is an exact multiple of the block size and the
    whole prompt has been cached by an earlier identical request. With
    ``use_eagle=True`` only one block (16 tokens) may be reported as
    computed for the second request.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Request with 3 full blocks (48 tokens)
    token_ids = [0] * (3 * block_size)
    req = make_request("divisible_request", token_ids)

    # Prime the cache: allocate the blocks for the first request, then free
    # it so that an identical request can look them up again.
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids), computed_blocks)
    manager.free(req)

    # New request with same tokens + Eagle enabled
    req_eagle = make_request("eagle_divisible", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Should retain 1 block:
    # 1. Original 3 blocks → pop last hash → 2 matched blocks
    # 2. Eagle is enabled and there is a cache hit → the last matched
    #    block is dropped as well → 1 remaining block
    assert len(computed_blocks) == 1
    assert num_tokens == 1 * block_size  # 16 tokens
|
||||
|
||||
|
||||
def test_eagle_with_partial_blocks():
    """Check the Eagle cache-hit result for a prompt ending in a partial block.

    The prompt covers two full blocks plus five extra tokens. After an
    identical request has primed the cache, the Eagle-enabled lookup is
    expected to report a single computed block.
    """
    block_size = 16

    # 2 full blocks + 5 tokens (non-divisible length)
    token_ids = [0] * (2 * block_size + 5)

    kv_cache_config = make_kv_cache_config(block_size, num_blocks=10)
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Prime the cache with a first request, then release it.
    warmup_req = make_request("partial_block_test", token_ids)
    warmup_blocks, _ = manager.get_computed_blocks(warmup_req)
    manager.allocate_slots(warmup_req, len(token_ids), warmup_blocks)
    manager.free(warmup_req)

    # Second request with the same tokens, looked up with Eagle enabled.
    eagle_req = make_request("partial_eagle", token_ids)
    hit_blocks, hit_tokens = manager.get_computed_blocks(eagle_req)

    # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
    expected_blocks = 1
    assert len(hit_blocks) == expected_blocks
    assert hit_tokens == expected_blocks * block_size
|
||||
|
||||
@ -44,7 +44,6 @@ def test_prompts():
|
||||
|
||||
@pytest.fixture
def sampling_config():
    """Return the sampling parameters used by tests requesting this fixture.

    Only greedy sampling is supported for now, hence ``temperature=0``.
    """
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=10,
        ignore_eos=False,
    )
    return greedy_params
|
||||
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ class KVCacheManager:
|
||||
max_model_len: int,
|
||||
enable_caching: bool = True,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1, (
|
||||
@ -38,6 +39,7 @@ class KVCacheManager:
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
@ -134,6 +136,14 @@ class KVCacheManager:
|
||||
|
||||
computed_blocks = (
|
||||
self.specialized_manager.find_longest_cache_hit(block_hashes))
|
||||
|
||||
if self.use_eagle and len(computed_blocks) > 0:
|
||||
# Drop the last matched block if (1) eagle is enabled and
|
||||
# (2) there is a cache hit.
|
||||
# This is to recompute the last block to get the required
|
||||
# hidden states for eagle drafting head.
|
||||
computed_blocks.pop()
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.queries += len(block_hashes)
|
||||
|
||||
@ -74,13 +74,6 @@ class Scheduler(SchedulerInterface):
|
||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||
|
||||
# Create the KV cache manager.
|
||||
self.kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
|
||||
log_stats=self.log_stats)
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
# req_id -> Request
|
||||
@ -123,12 +116,24 @@ class Scheduler(SchedulerInterface):
|
||||
cache_size=encoder_cache_size)
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.use_eagle = False
|
||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||
if speculative_config:
|
||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||
if speculative_config.use_eagle():
|
||||
self.use_eagle = True
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
|
||||
# Create the KV cache manager.
|
||||
self.kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
|
||||
use_eagle=self.use_eagle,
|
||||
log_stats=self.log_stats)
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
@ -317,7 +322,8 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Get already-cached tokens.
|
||||
computed_blocks, num_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(request)
|
||||
self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
num_external_tokens = (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user