[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
This commit is contained in:
Lucia Fang 2025-07-19 11:48:38 +08:00 committed by GitHub
parent 466e878f2a
commit 9a9fda1423
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 351 additions and 19 deletions

View File

@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
SlidingWindowSpec)
def get_sliding_window_manager(sliding_window_spec, block_pool):
@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
kv_cache_group_id=0)
def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
def test_chunked_local_attention_possible_cached_prefix():
block_size = 2
chunked_local_attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool)
def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,
max_length=len(block_hash_list) * block_size + tail_token,
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False)[0]
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
for block in computed_blocks[:(expect_length - 1) // 2])
run_one_case([True], 0, 1)
run_one_case([True], 1, 1)
run_one_case([True, False], 0, 2)
run_one_case([True, False], 1, 2)
run_one_case([True, True], 0, 2)
run_one_case([True, True], 1, 2)
run_one_case([True, True, False], 0, 2)
run_one_case([True, True, False], 1, 2)
run_one_case([True, True, True], 0, 3)
run_one_case([True, True, True], 1, 3)
run_one_case([True, True, True, False], 0, 4)
run_one_case([True, True, True, False], 1, 4)
run_one_case([random.choice([True, False])] * 8 + [True], 1, 9)
run_one_case([random.choice([True, False])] * 8 + [False], 1, 8)
run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
def test_sliding_window_possible_cached_prefix():
block_size = 2
sliding_window_spec = SlidingWindowSpec(
@ -84,6 +162,58 @@ def test_sliding_window_possible_cached_prefix():
], 8)
def test_chunked_local_attention_remove_skipped_blocks():
attention_spec = ChunkedLocalAttentionSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids) -> list[KVCacheBlock]:
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block
else:
assert block.block_id == id_
original_block_ids = [
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
manager.req_to_blocks["test"] = block_table
manager.remove_skipped_blocks("test", 0)
assert_block_id(block_table, original_block_ids)
# For 4th token (0-indexed), token 0-3 is out of the local attention window.
manager.remove_skipped_blocks("test", 4)
assert_block_id(block_table, [null_block_id] * 2)
# For 6th token (0-indexed), token 4 - 6 are in local attention window,
# token 0 - 3 are out, 2 blocks can be removed.
manager.remove_skipped_blocks("test", 6)
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# For 12th token (0-indexed),
# token 0-11 are out, 6 block can be removed.
manager.remove_skipped_blocks("test", 12)
assert_block_id(block_table, [null_block_id] * 6)
def test_sliding_window_remove_skipped_blocks():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate():
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15
def test_chunked_local_attention_get_num_blocks_to_allocate():
block_size = 2
attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)
] + [KVCacheBlock(i + 1) for i in range(5)]
assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15

View File

@ -172,6 +172,7 @@ class Attention(nn.Module):
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
self.use_irope = extra_impl_args.get("use_irope", False)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant

View File

@ -4722,6 +4722,13 @@ class VllmConfig:
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.model_config is not None and \
self.model_config.attention_chunk_size is not None and \
self.speculative_config is not None and \
self.speculative_config.use_eagle():
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:

View File

@ -538,6 +538,7 @@ def use_cascade_attention(
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
) -> bool:
"""Decide whether to use cascade attention.
@ -553,7 +554,7 @@ def use_cascade_attention(
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
if use_alibi or use_sliding_window:
if use_alibi or use_sliding_window or use_local_attention:
return False
# Too few queries. Probably not worth using cascade attention.
# We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.

View File

@ -120,6 +120,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
) -> bool:
return False

View File

@ -11,7 +11,8 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
@ -976,7 +977,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec)
for spec in kv_cache_spec.values())
if has_full_attention and (has_sliding_window
or has_chunked_local_attention):
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
@ -987,6 +992,15 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
@ -1010,7 +1024,6 @@ def get_kv_cache_config(
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

View File

@ -394,6 +394,129 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return 0
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
block_pool: BlockPool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
prefix of the blocks that is not longer than `max_length`. The prefix
should be a common prefix hit for all the kv cache groups in
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
note we mark as computed if the whole block is outside of the local
window, and set the block as null. Examples:
1. Attention chunk size of 8, block size of 4, max length of 15
for next token at 15th (zero-indexed), 8th - 14th tokens are in
the window(needs lookup), 0th - 7th are not in the window,
so they are already marked as computed. We check the complete
block3 (8th - 11th tokens), Assume block 3 is hit, we will return
[null, null, block 3], otherwise, we return [null, null]
2. Attention chunk size of 8, block size of 4, max length of 16
for next token at 16th (zero-indexed), 0th - 15th tokens are not
in the window, so they are already marked as computed.
we return 4 blocks[null, null, null, null]
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A list of cached blocks
"""
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
"ChunkedLocalAttentionManager can only be used for " +
"chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.")
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (max_length //
kv_cache_spec.attention_chunk_size *
kv_cache_spec.attention_chunk_size)
else:
local_attention_start_idx = 0
# we marked blocks out of window as computed
# with null blocks, and blocks inside window based on cache lookup
# result [null] [null] ... [null] [hit block 1 (1st block contain
# last window)] [hit block 2] ... [hit block x]
local_attention_start_block_idx = (local_attention_start_idx //
kv_cache_spec.block_size)
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[block_pool.null_block] * local_attention_start_block_idx
for _ in range(len(kv_cache_group_ids)))
for i in range(local_attention_start_block_idx, max_num_blocks):
block_hash = block_hashes[i]
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer be in the chunked attention
# window and skipped during the attention computation.
# [chunk 0][chunk 1]local_attention_start_idx ... current
# we computed previous number of chunks to get the idx of
# current chunk window starting offset,
# e.g. for computed 1024 tokens, the 1024th token (0 indexed)
# is in the second chunk, there are 1 prev chunk, the start idx
# is 1024. for 1023, it will be 0.
num_cached_block = self.num_cached_block.get(request_id, 0)
local_attention_start_idx = (
num_computed_tokens
) // self.attention_chunk_size * self.attention_chunk_size
first_useful_block_idx = local_attention_start_idx // self.block_size
if num_cached_block > 0:
# Make sure we don't delete the last cached block
first_useful_block_idx = min(first_useful_block_idx,
num_cached_block - 1)
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
# block 8, 372 (= 128 * 2 + 116) -> block 2
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# we need to keep the last block to get the previous hash key
for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
cascade attention is not supported by chunked local attention.
"""
return 0
class MambaManager(SingleTypeKVCacheManager):
@classmethod
@ -435,8 +558,8 @@ class MambaManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
ChunkedLocalAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
}

View File

@ -87,6 +87,7 @@ class AttentionSpec(KVCacheSpec):
@dataclass
class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
attention_chunk_size: Optional[int] = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
@ -105,6 +106,17 @@ class FullAttentionSpec(AttentionSpec):
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
if len(window_sizes) == 0:
return None
elif len(window_sizes) == 1:
return window_sizes.pop()
else:
raise ValueError(
"All attention layers in the same KV cache group must have the "
"same window size.")
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
@ -114,14 +126,17 @@ class FullAttentionSpec(AttentionSpec):
merged_spec = super().merge(specs)
sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
if len(sliding_window) == 0:
merged_spec.sliding_window = None
elif len(sliding_window) == 1:
merged_spec.sliding_window = sliding_window.pop()
else:
raise ValueError(
"All sliding window layers in the same KV cache group "
"must have the same window size.")
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
merged_spec.attention_chunk_size = (
cls.merge_window_sizes(attention_chunk_size))
assert (
(merged_spec.sliding_window is not None) +
(merged_spec.attention_chunk_size is not None) <= 1
), ("Model with both sliding window layers and chunked local attention "
"layers is not supported.")
return merged_spec
@ -129,16 +144,26 @@ class FullAttentionSpec(AttentionSpec):
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@property
def type_id(self) -> str:
return (
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
) # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# During chunked prefill, we allocate KV cache for at most
# `self.attention_chunk_size` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
max_model_len)
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):

View File

@ -862,6 +862,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
(isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.sliding_window is not None))
use_local_attention = (
isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
or (isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.attention_chunk_size is not None))
assert isinstance(kv_cache_spec, AttentionSpec)
use_cascade = attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
@ -870,6 +874,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads=kv_cache_spec.num_kv_heads,
use_alibi=self.use_alibi,
use_sliding_window=use_sliding_window,
use_local_attention=use_local_attention,
num_sms=self.num_sms,
)
return common_prefix_len if use_cascade else 0
@ -2672,6 +2677,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
assert not use_local_attention, (
"attention module can not be with ",
"both local attention and sliding window")
elif use_local_attention:
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
block_size=block_size,