[V1][Spec Decode] KV cache slots for eagle heads (#16370)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

commit f49e5aff11, parent 6c11ecf8d3
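In short, the scheduler now asks the KV cache manager to reserve slots for the draft tokens an eagle head will propose, so speculative tokens cannot overrun a request's allocated blocks. A minimal sketch of the before/after block math, assuming the values from test case 1 below (cdiv mirrors vllm.utils.cdiv, i.e. ceiling division):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, mirroring vllm.utils.cdiv.
        return -(-a // b)

    block_size = 4
    num_computed_tokens = 0   # nothing cached yet
    num_tokens = 3            # new tokens appended this step
    num_lookahead_tokens = 2  # eagle draft tokens

    # Before this commit, lookahead tokens were not reserved:
    assert cdiv(num_computed_tokens + num_tokens, block_size) == 1
    # After it, the draft tokens get KV cache slots up front:
    assert cdiv(num_computed_tokens + num_tokens + num_lookahead_tokens,
                block_size) == 2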
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256
+from vllm.v1.core.kv_cache_manager import KVCacheManager
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
 from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
@@ -48,6 +49,18 @@ def make_request(request_id,
     )


+def new_kv_cache_spec(block_size=16,
+                      num_kv_heads=2,
+                      head_size=64,
+                      dtype=torch.float32,
+                      use_mla=False):
+    return FullAttentionSpec(block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=dtype,
+                             use_mla=use_mla)
+
+
 def test_none_hash():
     assert NONE_HASH is not None
     assert isinstance(NONE_HASH, int)
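The new_kv_cache_spec helper is hoisted to module scope here (it previously lived inside test_unify_kv_cache_configs and is removed from there in the next hunk) so the new lookahead test can reuse it. A hedged usage sketch, assuming the test module's existing imports (torch, and FullAttentionSpec from vllm.v1.kv_cache_interface):

    # Build a default full-attention spec, overriding only the block size:
    spec = new_kv_cache_spec(block_size=4)
    assert spec.block_size == 4
    assert spec.num_kv_heads == 2 and spec.head_size == 64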
@@ -327,18 +340,6 @@ def test_metrics():


 def test_unify_kv_cache_configs():
-
-    def new_kv_cache_spec(block_size=16,
-                          num_kv_heads=2,
-                          head_size=64,
-                          dtype=torch.float32,
-                          use_mla=False):
-        return FullAttentionSpec(block_size=block_size,
-                                 num_kv_heads=num_kv_heads,
-                                 head_size=head_size,
-                                 dtype=dtype,
-                                 use_mla=use_mla)
-
     same_kv_cache_config = [
         KVCacheConfig(
             num_blocks=10,
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
     estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                                8 * GiB_bytes)
     assert estimated_max_len == want_estimated_max_len
+
+
+def test_allocate_with_lookahead():
+    """Verify that lookahead tokens correctly affect block allocation"""
+    block_size = 4
+    config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer1": KVCacheTensor(100),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer1"],
+                             new_kv_cache_spec(block_size=block_size)),
+        ],
+    )
+
+    request = make_request(
+        request_id=0,
+        prompt_token_ids=[],
+        mm_positions=None,
+        mm_hashes=None,
+    )
+
+    # Test case 1: Requires additional lookahead tokens
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=0)
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
+    )
+    assert len(blocks) == 2  # ceil(5/4)=2 blocks
+
+    # Test case 2: With preallocated blocks
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=4)
+    # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
+    # required_blocks = ceil((3 + 2) / 4) = 2
+    # total_blocks = 1 + 2 = 3
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=2,
+    )
+    assert len(blocks) == 3
+
+    # Test case 3: With preallocated blocks
+    # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
+    # required_blocks = ceil((3 + 4) / 4) = 2
+    # total_blocks = 0 + 2 = 2
+    kv_cache_manager = KVCacheManager(kv_cache_config=config,
+                                      max_model_len=100,
+                                      num_preallocate_tokens=4)
+    blocks = kv_cache_manager.allocate_slots(
+        request,
+        num_tokens=3,
+        num_lookahead_tokens=4,
+    )
+    assert len(blocks) == 2
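The expected counts in the three assertions can be re-derived from the hunk's own arithmetic comments; a small sketch (block_size=4, cdiv is ceiling division):

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    # Case 1: no preallocation; required = ceil((3 + 2) / 4) = 2 blocks.
    assert cdiv(3 + 2, 4) == 2
    # Case 2: preallocate 4 tokens with 2 lookahead; lookahead eats into
    # the preallocation budget: extra = 4 // 4 - 2 // 4 = 1, total 1 + 2 = 3.
    assert 4 // 4 - 2 // 4 + cdiv(3 + 2, 4) == 3
    # Case 3: 4 lookahead tokens consume the whole budget:
    # extra = 4 // 4 - 4 // 4 = 0, required = ceil((3 + 4) / 4) = 2.
    assert 4 // 4 - 4 // 4 + cdiv(3 + 4, 4) == 2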
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -164,7 +164,8 @@ class KVCacheManager:
         self,
         request: Request,
         num_tokens: int,
-        new_computed_blocks: Optional[list[KVCacheBlock]] = None
+        new_computed_blocks: Optional[list[KVCacheBlock]] = None,
+        num_lookahead_tokens: int = 0,
     ) -> Optional[list[KVCacheBlock]]:
         """Add slots for a request with new tokens to append.

@@ -174,6 +175,9 @@ class KVCacheManager:
                 not include the tokens that have already been computed.
             new_computed_blocks: A list of new computed blocks just hitting the
                 prefix caching.
+            num_lookahead_tokens: The number of speculative tokens to allocate.
+                This is used by spec decode proposers with their own KV cache,
+                such as eagle.

         Blocks layout:
         -----------------------------------------------------------------------
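For illustration, a hedged sketch of calling the extended API; request and kv_cache_manager are assumed to be set up as in the test above:

    # Append 3 tokens and reserve room for 2 eagle draft tokens on top.
    blocks = kv_cache_manager.allocate_slots(
        request,
        num_tokens=3,
        num_lookahead_tokens=2,
    )
    if blocks is None:
        # Not enough free blocks; the scheduler may preempt another request.
        pass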
@@ -211,8 +215,9 @@ class KVCacheManager:
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
                                len(new_computed_blocks) * self.block_size)
-        num_required_blocks = cdiv(num_computed_tokens + num_tokens,
-                                   self.block_size)
+        num_required_blocks = cdiv(
+            num_computed_tokens + num_tokens + num_lookahead_tokens,
+            self.block_size)
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(new_computed_blocks))

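Concretely, the widened cdiv call makes lookahead tokens count toward the number of required blocks. A worked example with assumed values:

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    block_size = 4
    num_computed_tokens = 4      # already cached (fills one block)
    num_tokens = 3               # appended this step
    num_lookahead_tokens = 2     # eagle draft tokens
    len_req_blocks = 1           # blocks the request already holds
    len_new_computed_blocks = 0  # no new prefix cache hits

    num_required_blocks = cdiv(
        num_computed_tokens + num_tokens + num_lookahead_tokens, block_size)
    num_new_blocks = (num_required_blocks - len_req_blocks -
                      len_new_computed_blocks)
    assert num_required_blocks == 3 and num_new_blocks == 2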
@@ -246,8 +251,11 @@ class KVCacheManager:
         else:
             # Get new blocks from the free block pool considering
             # preallocated blocks.
+            num_preallocate_blocks = max(
+                0, self.num_preallocate_blocks -
+                num_lookahead_tokens // self.block_size)
             num_new_blocks = min(
-                num_new_blocks + self.num_preallocate_blocks,
+                num_new_blocks + num_preallocate_blocks,
                 self.block_pool.get_num_free_blocks(),
                 # Should not exceed the maximum number of blocks per request.
                 # This is especially because the block table has the shape
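Since lookahead tokens are now counted as required blocks, they are deducted from the speculative preallocation budget so the two mechanisms do not double-reserve, and the max(0, ...) clamp keeps the budget non-negative. A small sketch with assumed values:

    block_size = 4
    configured_blocks = 2  # stand-in for self.num_preallocate_blocks

    for num_lookahead_tokens in (0, 4, 8, 12):
        budget = max(0, configured_blocks -
                     num_lookahead_tokens // block_size)
        print(num_lookahead_tokens, "->", budget)  # 0->2, 4->1, 8->0, 12->0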
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -7,7 +7,8 @@ from collections import deque
 from collections.abc import Iterable
 from typing import Optional, Union

-from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
         lora_config: Optional[LoRAConfig],
         kv_cache_config: KVCacheConfig,
         structured_output_manager: StructuredOutputManager,
+        speculative_config: SpeculativeConfig = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)

+        self.num_lookahead_tokens = 0
+        if speculative_config and speculative_config.method == "eagle":
+            self.num_lookahead_tokens = \
+                speculative_config.num_speculative_tokens
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
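The lookahead count is only non-zero for the eagle method; other proposers (for example the draft-model-free ngram proposer) presumably need no KV slots of their own. A standalone rendering of the gating logic, with a hypothetical stand-in config class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeSpeculativeConfig:  # hypothetical stand-in for SpeculativeConfig
        method: str
        num_speculative_tokens: int

    def lookahead_tokens(cfg: Optional[FakeSpeculativeConfig]) -> int:
        if cfg and cfg.method == "eagle":
            return cfg.num_speculative_tokens
        return 0

    assert lookahead_tokens(FakeSpeculativeConfig("eagle", 2)) == 2
    assert lookahead_tokens(FakeSpeculativeConfig("ngram", 2)) == 0
    assert lookahead_tokens(None) == 0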
@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):

         while True:
             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens)
+                request,
+                num_new_tokens,
+                num_lookahead_tokens=self.num_lookahead_tokens)
             if new_blocks is None:
                 # The request cannot be scheduled.
                 # Preempt the lowest-priority request.
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -98,6 +98,7 @@ class EngineCore:
             cache_config=vllm_config.cache_config,
             lora_config=vllm_config.lora_config,
             kv_cache_config=kv_cache_config,
+            speculative_config=vllm_config.speculative_config,
             structured_output_manager=self.structured_output_manager,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,
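End to end, the setting flows from vllm_config.speculative_config through the Scheduler into KVCacheManager.allocate_slots. A hedged sketch of a user-facing configuration that would exercise this path; the exact argument shape and the draft-model name are assumptions, not taken from this commit:

    from vllm import LLM

    # Hypothetical EAGLE setup; model names and config keys are assumptions.
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        speculative_config={
            "method": "eagle",
            "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
            "num_speculative_tokens": 2,
        },
    )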