[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
Lucia Fang 2025-07-19 11:48:38 +08:00 committed by GitHub
parent 466e878f2a
commit 9a9fda1423
9 changed files with 351 additions and 19 deletions

View File

@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
SlidingWindowSpec)
def get_sliding_window_manager(sliding_window_spec, block_pool):
@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
kv_cache_group_id=0)
def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
def test_chunked_local_attention_possible_cached_prefix():
block_size = 2
chunked_local_attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool)
def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,
max_length=len(block_hash_list) * block_size + tail_token,
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False)[0]
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
for block in computed_blocks[:(expect_length - 1) // 2])
run_one_case([True], 0, 1)
run_one_case([True], 1, 1)
run_one_case([True, False], 0, 2)
run_one_case([True, False], 1, 2)
run_one_case([True, True], 0, 2)
run_one_case([True, True], 1, 2)
run_one_case([True, True, False], 0, 2)
run_one_case([True, True, False], 1, 2)
run_one_case([True, True, True], 0, 3)
run_one_case([True, True, True], 1, 3)
run_one_case([True, True, True, False], 0, 4)
run_one_case([True, True, True, False], 1, 4)
run_one_case([random.choice([True, False])] * 8 + [True], 1, 9)
run_one_case([random.choice([True, False])] * 8 + [False], 1, 8)
run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
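The expected lengths above follow from the chunk-window arithmetic. Below is a minimal standalone sketch of that arithmetic, assuming block_size=2 and attention_chunk_size=4 as configured in this test; expected_hit_length is a hypothetical helper for illustration only, not part of the patch.

def expected_hit_length(block_is_cached: list[bool], tail_token: int,
                        block_size: int = 2, chunk_size: int = 4) -> int:
    # Total token count the lookup covers.
    max_length = len(block_is_cached) * block_size + tail_token
    # Blocks that end before the current chunk window count as computed
    # (null) blocks regardless of the cache contents.
    window_start = max_length // chunk_size * chunk_size
    start_block = window_start // block_size
    hits = start_block
    # Inside the window, the hit is extended only by consecutive cached
    # complete blocks.
    for i in range(start_block, max_length // block_size):
        if not block_is_cached[i]:
            break
        hits += 1
    return hits

assert expected_hit_length([True, True, True, False], tail_token=0) == 4
assert expected_hit_length([True, True, True], tail_token=1) == 3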
def test_sliding_window_possible_cached_prefix():
block_size = 2
sliding_window_spec = SlidingWindowSpec(
@ -84,6 +162,58 @@ def test_sliding_window_possible_cached_prefix():
], 8)
def test_chunked_local_attention_remove_skipped_blocks():
attention_spec = ChunkedLocalAttentionSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids) -> list[KVCacheBlock]:
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block
else:
assert block.block_id == id_
original_block_ids = [
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
manager.req_to_blocks["test"] = block_table
manager.remove_skipped_blocks("test", 0)
assert_block_id(block_table, original_block_ids)
# For the 4th token (0-indexed), tokens 0-3 are out of the local attention window.
manager.remove_skipped_blocks("test", 4)
assert_block_id(block_table, [null_block_id] * 2)
# For the 6th token (0-indexed), tokens 4-6 are in the local attention window
# and tokens 0-3 are out, so 2 blocks can be removed.
manager.remove_skipped_blocks("test", 6)
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# For the 12th token (0-indexed),
# tokens 0-11 are out, so 6 blocks can be removed.
manager.remove_skipped_blocks("test", 12)
assert_block_id(block_table, [null_block_id] * 6)
def test_sliding_window_remove_skipped_blocks():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate():
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15
def test_chunked_local_attention_get_num_blocks_to_allocate():
block_size = 2
attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)
] + [KVCacheBlock(i + 1) for i in range(5)]
assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
cached_blocks_1) == 20
assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
cached_blocks_2) == 15

View File

@ -172,6 +172,7 @@ class Attention(nn.Module):
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
self.use_irope = extra_impl_args.get("use_irope", False)
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant

View File

@ -4722,6 +4722,13 @@ class VllmConfig:
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.model_config is not None and \
self.model_config.attention_chunk_size is not None and \
self.speculative_config is not None and \
self.speculative_config.use_eagle():
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:

View File

@ -538,6 +538,7 @@ def use_cascade_attention(
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
) -> bool:
"""Decide whether to use cascade attention.
@ -553,7 +554,7 @@ def use_cascade_attention(
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
if use_alibi or use_sliding_window or use_local_attention:
return False
# Too few queries. Probably not worth using cascade attention.
# We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.

View File

@ -120,6 +120,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
) -> bool:
return False

View File

@ -11,7 +11,8 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
@ -976,7 +977,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec)
for spec in kv_cache_spec.values())
if has_full_attention and (has_sliding_window
or has_chunked_local_attention):
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
@ -987,6 +992,15 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
@ -1010,7 +1024,6 @@ def get_kv_cache_config(
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

View File

@ -394,6 +394,129 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return 0
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
block_pool: BlockPool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
prefix of the blocks that is not longer than `max_length`. The prefix
should be a common prefix hit for all the kv cache groups in
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
Note that we mark a block as computed and set it to the null block if
the whole block falls outside of the local attention window. Examples:
1. Attention chunk size 8, block size 4, max length 15:
for the next token at index 15 (zero-indexed), tokens 8-14 are in the
window (need lookup) and tokens 0-7 are not, so they are already
marked as computed. We check the complete third block (tokens 8-11);
if it is a cache hit, we return [null, null, block 3], otherwise we
return [null, null].
2. Attention chunk size 8, block size 4, max length 16:
for the next token at index 16 (zero-indexed), tokens 0-15 are not in
the window, so they are already marked as computed and we return
4 blocks: [null, null, null, null].
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A tuple with one list of cached blocks per kv cache group.
"""
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
"ChunkedLocalAttentionManager can only be used for " +
"chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.")
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (max_length //
kv_cache_spec.attention_chunk_size *
kv_cache_spec.attention_chunk_size)
else:
local_attention_start_idx = 0
# Blocks outside the window are marked as computed using null blocks;
# blocks inside the window are filled in based on the cache lookup
# result: [null] [null] ... [null] [hit block 1 (first block overlapping
# the current window)] [hit block 2] ... [hit block x]
local_attention_start_block_idx = (local_attention_start_idx //
kv_cache_spec.block_size)
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[block_pool.null_block] * local_attention_start_block_idx
for _ in range(len(kv_cache_group_ids)))
for i in range(local_attention_start_block_idx, max_num_blocks):
block_hash = block_hashes[i]
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
return computed_blocks
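Worked check of the docstring examples above; an illustrative sketch only, not part of the patch, using attention_chunk_size=8 and block_size=4.

chunk_size, block_size = 8, 4

# max_length=15: the window starts at token 8, so the first 2 blocks
# (tokens 0-7) come back as null and block index 2 (tokens 8-11) is the
# first block that is actually looked up in the prefix cache.
start_idx = 15 // chunk_size * chunk_size
assert start_idx == 8 and start_idx // block_size == 2

# max_length=16: the window starts at token 16, so all 4 complete blocks
# (tokens 0-15) come back as null and nothing is looked up.
start_idx = 16 // chunk_size * chunk_size
assert start_idx == 16 and start_idx // block_size == 4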
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer in the chunked attention
# window and are skipped during the attention computation.
# [chunk 0][chunk 1]local_attention_start_idx ... current
# We round down to a whole number of chunks to get the starting offset
# of the current chunk window, e.g. with a chunk size of 1024 and 1024
# computed tokens, the next (1024th, 0-indexed) token starts a new
# chunk, so there is 1 previous chunk and the start idx is 1024; with
# 1023 computed tokens it is 0.
num_cached_block = self.num_cached_block.get(request_id, 0)
local_attention_start_idx = (
num_computed_tokens
) // self.attention_chunk_size * self.attention_chunk_size
first_useful_block_idx = local_attention_start_idx // self.block_size
if num_cached_block > 0:
# Make sure we don't delete the last cached block
first_useful_block_idx = min(first_useful_block_idx,
num_cached_block - 1)
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
# block 8, 372 (= 128 * 2 + 116) -> block 2
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# we need to keep the last block to get the previous hash key
for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
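Illustrative arithmetic only (not part of the patch): with block_size=2 and attention_chunk_size=4, as in the unit tests earlier in this commit, every block that ends before the start of the current chunk window becomes a null block.

chunk_size, block_size = 4, 2
for num_computed_tokens, expected_first_useful_block in [(0, 0), (4, 2),
                                                         (6, 2), (12, 6)]:
    window_start = num_computed_tokens // chunk_size * chunk_size
    assert window_start // block_size == expected_first_useful_block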
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
Cascade attention is not supported by chunked local attention.
"""
return 0
class MambaManager(SingleTypeKVCacheManager):
@classmethod
@ -435,8 +558,8 @@ class MambaManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
- ChunkedLocalAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
}

View File

@ -87,6 +87,7 @@ class AttentionSpec(KVCacheSpec):
@dataclass
class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
attention_chunk_size: Optional[int] = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
@ -105,6 +106,17 @@ class FullAttentionSpec(AttentionSpec):
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
if len(window_sizes) == 0:
return None
elif len(window_sizes) == 1:
return window_sizes.pop()
else:
raise ValueError(
"All attention layers in the same KV cache group must have the "
"same window size.")
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
@ -114,14 +126,17 @@ class FullAttentionSpec(AttentionSpec):
merged_spec = super().merge(specs)
sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
- if len(sliding_window) == 0:
- merged_spec.sliding_window = None
- elif len(sliding_window) == 1:
- merged_spec.sliding_window = sliding_window.pop()
- else:
- raise ValueError(
- "All sliding window layers in the same KV cache group "
- "must have the same window size.")
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
merged_spec.attention_chunk_size = (
cls.merge_window_sizes(attention_chunk_size))
assert (
(merged_spec.sliding_window is not None) +
(merged_spec.attention_chunk_size is not None) <= 1
), ("Model with both sliding window layers and chunked local attention "
"layers is not supported.")
return merged_spec
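For reference, a self-contained sketch of the merging rule above; merge_window_sizes here mirrors, rather than imports, the new classmethod, and the numbers are hypothetical.

from typing import Optional

def merge_window_sizes(window_sizes: set[int]) -> Optional[int]:
    # All attention layers in one KV cache group must agree on the window size.
    if not window_sizes:
        return None
    if len(window_sizes) == 1:
        return next(iter(window_sizes))
    raise ValueError("All attention layers in the same KV cache group must "
                     "have the same window size.")

sliding_window = merge_window_sizes({4096})        # layers agree -> 4096
attention_chunk_size = merge_window_sizes(set())   # no chunked layers -> None
# A merged FullAttentionSpec may carry a sliding window or an attention
# chunk size, but not both.
assert (sliding_window is not None) + (attention_chunk_size is not None) <= 1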
@ -129,16 +144,26 @@ class FullAttentionSpec(AttentionSpec):
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
- def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
- max_model_len = vllm_config.model_config.max_model_len
- return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@property
def type_id(self) -> str:
return (
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
) # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# During chunked prefill, we allocate KV cache for at most
# `self.attention_chunk_size` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
max_model_len)
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
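Worked example of the bound above with hypothetical numbers, not taken from the patch: attention_chunk_size=8192, max_num_batched_tokens=2048, max_model_len=131072, block_size=16.

chunk, batch, model_len, block_size = 8192, 2048, 131072, 16

num_tokens = min(chunk + batch, model_len)            # 10240 tokens to cover
num_blocks = -(-num_tokens // block_size)             # cdiv -> 640 blocks
full_attention_blocks = -(-model_len // block_size)   # 8192 blocks
assert (num_blocks, full_attention_blocks) == (640, 8192)
# The per-request KV budget is num_blocks * page_size_bytes, far below the
# max_model_len-based budget a FullAttentionSpec would reserve.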
@dataclass
class SlidingWindowSpec(AttentionSpec):

View File

@ -862,6 +862,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
(isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.sliding_window is not None))
use_local_attention = (
isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
or (isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.attention_chunk_size is not None))
assert isinstance(kv_cache_spec, AttentionSpec)
use_cascade = attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
@ -870,6 +874,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads=kv_cache_spec.num_kv_heads,
use_alibi=self.use_alibi,
use_sliding_window=use_sliding_window,
use_local_attention=use_local_attention,
num_sms=self.num_sms,
)
return common_prefix_len if use_cascade else 0
@ -2672,6 +2677,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
assert not use_local_attention, (
"attention module cannot have both local attention "
"and sliding window")
elif use_local_attention:
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
block_size=block_size,