[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>

Parent: 466e878f2a
Commit: 9a9fda1423
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import random
+
 import torch

 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                          KVCacheBlock)
-from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
-from vllm.v1.kv_cache_interface import SlidingWindowSpec
+from vllm.v1.core.single_type_kv_cache_manager import (
+    ChunkedLocalAttentionManager, SlidingWindowManager)
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        SlidingWindowSpec)


 def get_sliding_window_manager(sliding_window_spec, block_pool):
@@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
                                 kv_cache_group_id=0)


+def get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                        block_pool):
+    return ChunkedLocalAttentionManager(chunked_local_attention_spec,
+                                        block_pool,
+                                        caching_hash_fn=lambda x: x,
+                                        kv_cache_group_id=0)
+
+
+def test_chunked_local_attention_possible_cached_prefix():
+    block_size = 2
+    chunked_local_attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                                  block_pool)
+
+    def run_one_case(block_is_cached, tail_token, expect_length):
+        block_hash_list = [
+            BlockHash(i, ()) for i in range(len(block_is_cached))
+        ]
+
+        block_pool.cached_block_hash_to_block.clear()
+
+        # Mock the block pool with the cached blocks
+        for i, (block_hash,
+                is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
+            if is_cached:
+                block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
+                    block_hash, 0)] = {
+                        i: block_pool.blocks[i + 10],
+                    }
+
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hashes=block_hash_list,
+            max_length=len(block_hash_list) * block_size + tail_token,
+            kv_cache_group_ids=[0],
+            block_pool=block_pool,
+            kv_cache_spec=chunked_local_attention_spec,
+            use_eagle=False)[0]
+        assert len(computed_blocks) == expect_length
+
+        assert all(block == block_pool.null_block
+                   for block in computed_blocks[:(expect_length - 1) // 2])
+
+    run_one_case([True], 0, 1)
+    run_one_case([True], 1, 1)
+    run_one_case([True, False], 0, 2)
+    run_one_case([True, False], 1, 2)
+    run_one_case([True, True], 0, 2)
+    run_one_case([True, True], 1, 2)
+    run_one_case([True, True, False], 0, 2)
+    run_one_case([True, True, False], 1, 2)
+    run_one_case([True, True, True], 0, 3)
+    run_one_case([True, True, True], 1, 3)
+    run_one_case([True, True, True, False], 0, 4)
+    run_one_case([True, True, True, False], 1, 4)
+    run_one_case([random.choice([True, False])] * 8 + [True], 1, 9)
+    run_one_case([random.choice([True, False])] * 8 + [False], 1, 8)
+    run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
+
+
 def test_sliding_window_possible_cached_prefix():
     block_size = 2
     sliding_window_spec = SlidingWindowSpec(
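
The expected lengths in run_one_case above follow one rule: every block that lies entirely outside the current local-attention window counts as computed (and is returned as a null block), and inside the window the hit only extends over a contiguous run of cached blocks. A minimal standalone sketch of that arithmetic (not part of the diff; the helper name is ours), assuming block_size=2 and attention_chunk_size=4 as in the test:

def expected_prefix_length(block_is_cached, tail_token,
                           block_size=2, chunk_size=4):
    max_length = len(block_is_cached) * block_size + tail_token
    # Blocks entirely outside the local window are treated as computed (null).
    start_block = (max_length // chunk_size * chunk_size) // block_size
    length = start_block
    # Inside the window, the hit must be a contiguous run of cached blocks.
    for cached in block_is_cached[start_block:max_length // block_size]:
        if not cached:
            break
        length += 1
    return length


assert expected_prefix_length([True, True, False], 0) == 2
assert expected_prefix_length([True, True, True], 1) == 3
assert expected_prefix_length([False] * 8 + [True, False], 1) == 10
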
@@ -84,6 +162,58 @@ def test_sliding_window_possible_cached_prefix():
     ], 8)


+def test_chunked_local_attention_remove_skipped_blocks():
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=2,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+
+    null_block_id = block_pool.null_block.block_id
+
+    def id_to_block_table(ids) -> list[KVCacheBlock]:
+        return [
+            KVCacheBlock(id_)
+            if id_ != null_block_id else block_pool.null_block for id_ in ids
+        ]
+
+    def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
+        for block, id_ in zip(block_table, ids):
+            if id_ == null_block_id:
+                assert block == block_pool.null_block
+            else:
+                assert block.block_id == id_
+
+    original_block_ids = [
+        1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
+    ]
+    block_table = id_to_block_table(original_block_ids)
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
+    assert_block_id(block_table, original_block_ids)
+
+    # With 4 computed tokens, tokens 0-3 are out of the local attention window.
+    manager.remove_skipped_blocks("test", 4)
+    assert_block_id(block_table, [null_block_id] * 2)
+
+    # With 6 computed tokens, tokens 4-6 are in the local attention window;
+    # tokens 0-3 are out, so 2 blocks can be removed.
+    manager.remove_skipped_blocks("test", 6)
+    assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
+    # With 12 computed tokens,
+    # tokens 0-11 are out, so 6 blocks can be removed.
+    manager.remove_skipped_blocks("test", 12)
+    assert_block_id(block_table, [null_block_id] * 6)
+
+
 def test_sliding_window_remove_skipped_blocks():
     sliding_window_spec = SlidingWindowSpec(
         block_size=2,
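
The block indices the test above expects to be nulled out come from the same window arithmetic: the window starts at the beginning of the current chunk, and every block that ends before that start can be freed. A standalone sketch (not part of the diff; the helper name is ours, and it ignores the manager's extra clamp that always keeps the last cached block), assuming block_size=2 and attention_chunk_size=4:

def first_useful_block(num_computed_tokens, block_size=2, chunk_size=4):
    window_start = num_computed_tokens // chunk_size * chunk_size
    # Blocks with indices below the returned value can be replaced by nulls.
    return window_start // block_size


assert first_useful_block(0) == 0    # nothing to free yet
assert first_useful_block(4) == 2    # blocks 0-1 freed
assert first_useful_block(6) == 2    # still only blocks 0-1
assert first_useful_block(12) == 6   # blocks 0-5 freed
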
@@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate():
                                               cached_blocks_1) == 20
     assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
                                               cached_blocks_2) == 15
+
+
+def test_chunked_local_attention_get_num_blocks_to_allocate():
+    block_size = 2
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,  # Placeholder value, not related to test result
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+    cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
+    cached_blocks_2 = [block_pool.null_block for _ in range(5)
+                       ] + [KVCacheBlock(i + 1) for i in range(5)]
+
+    assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
+                                              cached_blocks_1) == 20
+    assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
+                                              cached_blocks_2) == 15
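
Both new assertions ask for 20 * block_size = 40 tokens, i.e. 20 blocks. The expected values are consistent with cached blocks still counting toward the pool budget while the shared null blocks (placeholders for positions outside the local window) do not, which is why the second case needs 5 fewer. A standalone sketch of that bookkeeping (not part of the diff; the helper name and the interpretation are ours), assuming the request has no blocks allocated yet:

def blocks_to_allocate(num_tokens, num_null_cached, block_size=2):
    required = -(-num_tokens // block_size)  # ceiling division
    # Shared null blocks need no real allocation from the pool.
    return required - num_null_cached


assert blocks_to_allocate(40, num_null_cached=0) == 20   # cached_blocks_1
assert blocks_to_allocate(40, num_null_cached=5) == 15   # cached_blocks_2
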
@@ -172,6 +172,7 @@ class Attention(nn.Module):
                              kv_sharing_target_layer_name, **extra_impl_args)
         self.backend = backend_name_to_enum(attn_backend.get_name())
         self.dtype = dtype
+        self.use_irope = extra_impl_args.get("use_irope", False)

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
@@ -4722,6 +4722,13 @@ class VllmConfig:
         if self.kv_events_config is not None:
             # Hybrid KV cache manager is not compatible with KV events.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
+        if self.model_config is not None and \
+            self.model_config.attention_chunk_size is not None and \
+            self.speculative_config is not None and \
+            self.speculative_config.use_eagle():
+            # Hybrid KV cache manager is not yet supported with chunked
+            # local attention + eagle.
+            self.scheduler_config.disable_hybrid_kv_cache_manager = True

     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
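
The new guard only fires when the model uses chunked local attention and EAGLE speculative decoding at the same time; in that case the hybrid KV cache manager is disabled and a unified cache is used instead. A standalone sketch of the predicate (not part of the diff; the helper name and arguments are ours):

def must_disable_hybrid_manager(attention_chunk_size, uses_eagle):
    # Chunked local attention + EAGLE is not yet supported by the hybrid
    # KV cache manager.
    return attention_chunk_size is not None and uses_eagle


assert must_disable_hybrid_manager(8192, uses_eagle=True) is True
assert must_disable_hybrid_manager(8192, uses_eagle=False) is False
assert must_disable_hybrid_manager(None, uses_eagle=True) is False
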
@@ -538,6 +538,7 @@ def use_cascade_attention(
     num_kv_heads: int,
     use_alibi: bool,
     use_sliding_window: bool,
+    use_local_attention: bool,
     num_sms: int,
 ) -> bool:
     """Decide whether to use cascade attention.
@@ -553,7 +554,7 @@
     if common_prefix_len < 256:
         return False
     # Cascade attention is currently not supported with these variants.
-    if use_alibi or use_sliding_window:
+    if use_alibi or use_sliding_window or use_local_attention:
         return False
     # Too few queries. Probably not worth using cascade attention.
     # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
@@ -120,6 +120,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         num_kv_heads: int,
         use_alibi: bool,
         use_sliding_window: bool,
+        use_local_attention: bool,
         num_sms: int,
     ) -> bool:
         return False
@@ -11,7 +11,8 @@ from typing import Any, Callable, NamedTuple, Optional
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
@@ -976,7 +977,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
         isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
     has_sliding_window = any(
         isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
-    if has_full_attention and has_sliding_window:
+    has_chunked_local_attention = any(
+        isinstance(spec, ChunkedLocalAttentionSpec)
+        for spec in kv_cache_spec.values())
+    if has_full_attention and (has_sliding_window
+                               or has_chunked_local_attention):
         for layer_name, spec in kv_cache_spec.items():
             if isinstance(spec, SlidingWindowSpec):
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -987,6 +992,15 @@
                     use_mla=spec.use_mla,
                     sliding_window=spec.sliding_window,
                 )
+            elif isinstance(spec, ChunkedLocalAttentionSpec):
+                kv_cache_spec[layer_name] = FullAttentionSpec(
+                    block_size=spec.block_size,
+                    num_kv_heads=spec.num_kv_heads,
+                    head_size=spec.head_size,
+                    dtype=spec.dtype,
+                    use_mla=spec.use_mla,
+                    attention_chunk_size=spec.attention_chunk_size,
+                )

     if is_hybrid(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
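
When the hybrid KV cache manager is disabled, chunked-local-attention layers (like sliding-window layers) are widened to full attention so that all layers can share one KV cache group, and the chunk size is carried along on the resulting spec. A minimal standalone sketch of that widening (not part of the diff; it uses a stripped-down stand-in spec type rather than the real classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Spec:  # stand-in for the real KVCacheSpec classes
    kind: str                      # "full", "sliding", or "chunked_local"
    sliding_window: Optional[int] = None
    attention_chunk_size: Optional[int] = None


def _unify(specs: dict) -> None:
    # Widen sliding-window and chunked-local layers to "full" while keeping
    # their window/chunk metadata, mirroring unify_hybrid_kv_cache_specs.
    for name, spec in list(specs.items()):
        if spec.kind in ("sliding", "chunked_local"):
            specs[name] = _Spec(kind="full",
                                sliding_window=spec.sliding_window,
                                attention_chunk_size=spec.attention_chunk_size)


layers = {"0": _Spec("full"),
          "1": _Spec("chunked_local", attention_chunk_size=8192)}
_unify(layers)
assert layers["1"].kind == "full"
assert layers["1"].attention_chunk_size == 8192
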
@@ -1010,7 +1024,6 @@ def get_kv_cache_config(
         The generated KVCacheConfigs
     """
     check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
-
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)

@@ -394,6 +394,129 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         return 0


+class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
+
+    def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
+                 block_pool: BlockPool, **kwargs) -> None:
+        super().__init__(kv_cache_spec, block_pool, **kwargs)
+        self.attention_chunk_size = kv_cache_spec.attention_chunk_size
+        self._null_block = block_pool.null_block
+
+    @classmethod
+    def find_longest_cache_hit(
+        cls,
+        block_hashes: list[BlockHash],
+        max_length: int,
+        kv_cache_group_ids: list[int],
+        block_pool: BlockPool,
+        kv_cache_spec: KVCacheSpec,
+        use_eagle: bool,
+    ) -> tuple[list[KVCacheBlock], ...]:
+        """
+        For chunked local attention, we need to find the longest cache hit
+        prefix of the blocks that is not longer than `max_length`. The prefix
+        should be a common prefix hit for all the kv cache groups in
+        `kv_cache_group_ids`. If no cache hit is found, return an empty list.
+        Note that we mark a block as computed if the whole block is outside
+        the local window, and represent it with the null block. Examples:
+
+        1. Attention chunk size of 8, block size of 4, max length of 15:
+           for the next token at index 15 (zero-indexed), tokens 8-14 are in
+           the window (and need a lookup), tokens 0-7 are not in the window,
+           so they are already marked as computed. We check the complete
+           block 3 (tokens 8-11). Assuming block 3 is hit, we will return
+           [null, null, block 3]; otherwise, we return [null, null].
+
+        2. Attention chunk size of 8, block size of 4, max length of 16:
+           for the next token at index 16 (zero-indexed), tokens 0-15 are not
+           in the window, so they are already marked as computed.
+           We return 4 blocks: [null, null, null, null].
+
+        Args:
+            block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.
+            kv_cache_group_ids: The ids of the kv cache groups.
+            block_pool: The block pool.
+            kv_cache_spec: The kv cache spec.
+            use_eagle: Whether to use eagle.
+
+        Returns:
+            A list of cached blocks for each kv cache group.
+        """
+        assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
+            "ChunkedLocalAttentionManager can only be used for "
+            "chunked local attention groups")
+        assert use_eagle is False, ("Hybrid KV cache is not supported for "
+                                    "eagle + chunked local attention.")
+        max_num_blocks = max_length // kv_cache_spec.block_size
+        if max_length > 0:
+            local_attention_start_idx = (max_length //
+                                         kv_cache_spec.attention_chunk_size *
+                                         kv_cache_spec.attention_chunk_size)
+        else:
+            local_attention_start_idx = 0
+        # We mark blocks outside the window as computed with null blocks,
+        # and blocks inside the window based on the cache lookup result:
+        # [null] [null] ... [null] [hit block 1 (first block containing
+        # the last window)] [hit block 2] ... [hit block x]
+        local_attention_start_block_idx = (local_attention_start_idx //
+                                           kv_cache_spec.block_size)
+        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [block_pool.null_block] * local_attention_start_block_idx
+            for _ in range(len(kv_cache_group_ids)))
+        for i in range(local_attention_start_block_idx, max_num_blocks):
+            block_hash = block_hashes[i]
+            if cached_block := block_pool.get_cached_block(
+                    block_hash, kv_cache_group_ids):
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed.append(cached)
+            else:
+                break
+        return computed_blocks
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        # Remove the blocks that are no longer in the chunked attention
+        # window and are skipped during the attention computation.
+
+        # [chunk 0][chunk 1]local_attention_start_idx ... current
+        # We count the number of previous chunks to get the index of
+        # the current chunk window's starting offset.
+        # E.g., for 1024 computed tokens, the 1024th token (0-indexed)
+        # is in the second chunk, there is 1 previous chunk, and the start
+        # idx is 1024; for 1023, it will be 0.
+        num_cached_block = self.num_cached_block.get(request_id, 0)
+        local_attention_start_idx = (
+            num_computed_tokens
+        ) // self.attention_chunk_size * self.attention_chunk_size
+        first_useful_block_idx = local_attention_start_idx // self.block_size
+        if num_cached_block > 0:
+            # Make sure we don't delete the last cached block
+            first_useful_block_idx = min(first_useful_block_idx,
+                                         num_cached_block - 1)
+        # If block size = 128: 0 -> block 0, 1024 (= 128 * 8) ->
+        # block 8, 372 (= 128 * 2 + 116) -> block 2.
+        blocks = self.req_to_blocks[request_id]
+        removed_blocks: list[KVCacheBlock] = []
+        # We need to keep the last block to get the previous hash key.
+        for i in range(first_useful_block_idx - 1, -1, -1):
+            if blocks[i] == self._null_block:
+                # If the block is already a null block, the blocks before it
+                # should also have been set to null blocks by the previous
+                # calls to this function.
+                break
+            removed_blocks.append(blocks[i])
+            blocks[i] = self._null_block
+        self.block_pool.free_blocks(removed_blocks)
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        Cascade attention is not supported by chunked local attention.
+        """
+        return 0
+
+
 class MambaManager(SingleTypeKVCacheManager):

     @classmethod
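
The two docstring examples in find_longest_cache_hit reduce to a small piece of block arithmetic; a standalone sketch (not part of the diff; the helper name is ours), assuming block_size=4 and attention_chunk_size=8 as in the examples:

def null_and_lookup_blocks(max_length, block_size=4, chunk_size=8):
    # Blocks whose tokens all precede the current chunk are returned as null;
    # only complete blocks from the chunk start up to max_length are looked up.
    window_start = max_length // chunk_size * chunk_size if max_length else 0
    num_null = window_start // block_size
    num_lookup = max_length // block_size - num_null
    return num_null, num_lookup


# Example 1: max length 15 -> [null, null] plus at most one looked-up block.
assert null_and_lookup_blocks(15) == (2, 1)
# Example 2: max length 16 -> all four blocks are outside the window.
assert null_and_lookup_blocks(16) == (4, 0)
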
@@ -435,8 +558,8 @@ class MambaManager(SingleTypeKVCacheManager):

 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
-    ChunkedLocalAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
+    ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
     MambaSpec: MambaManager,
 }

@@ -87,6 +87,7 @@ class AttentionSpec(KVCacheSpec):
 @dataclass
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
+    attention_chunk_size: Optional[int] = None
     """
     When hybrid allocator is disabled and the model contains both full
     attention layers and sliding window attention layers, sliding
@@ -105,6 +106,17 @@ class FullAttentionSpec(AttentionSpec):
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes

+    @classmethod
+    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
+        if len(window_sizes) == 0:
+            return None
+        elif len(window_sizes) == 1:
+            return window_sizes.pop()
+        else:
+            raise ValueError(
+                "All attention layers in the same KV cache group must have the "
+                "same window size.")
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
@@ -114,14 +126,17 @@ class FullAttentionSpec(AttentionSpec):
         merged_spec = super().merge(specs)
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
-        if len(sliding_window) == 0:
-            merged_spec.sliding_window = None
-        elif len(sliding_window) == 1:
-            merged_spec.sliding_window = sliding_window.pop()
-        else:
-            raise ValueError(
-                "All sliding window layers in the same KV cache group "
-                "must have the same window size.")
+        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
+                                   if spec.attention_chunk_size is not None)
+
+        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
+        merged_spec.attention_chunk_size = (
+            cls.merge_window_sizes(attention_chunk_size))
+        assert (
+            (merged_spec.sliding_window is not None) +
+            (merged_spec.attention_chunk_size is not None) <= 1
+        ), ("Model with both sliding window layers and chunked local attention "
+            "layers is not supported.")
         return merged_spec


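
The same set-based merge rule is now applied to both sliding_window and attention_chunk_size. A standalone sketch of that rule (not part of the diff; a free-function reimplementation rather than the classmethod above):

from typing import Optional


def merge_window_sizes(window_sizes: set) -> Optional[int]:
    if not window_sizes:
        return None                       # no layer in the group sets a window
    if len(window_sizes) == 1:
        return next(iter(window_sizes))   # all layers agree on one size
    raise ValueError("All attention layers in the same KV cache group must "
                     "have the same window size.")


assert merge_window_sizes(set()) is None
assert merge_window_sizes({4096}) == 4096
try:
    merge_window_sizes({1024, 4096})
    raise AssertionError("expected ValueError")
except ValueError:
    pass  # mixed sizes in one group are rejected

The assert at the end of merge relies on True/False acting as 1/0, so at most one of the two merged fields may be set: a model mixing sliding-window layers and chunked-local-attention layers is rejected.
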
@@ -129,16 +144,26 @@ class FullAttentionSpec(AttentionSpec):
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int

-    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-        max_model_len = vllm_config.model_config.max_model_len
-        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
-
     @property
     def type_id(self) -> str:
         return (
             f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
         )  # noqa

+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+
+        # During chunked prefill, we allocate KV cache for at most
+        # `self.attention_chunk_size` computed tokens plus the newly scheduled
+        # tokens. And we won't allocate KV cache for more than `max_model_len`
+        # tokens.
+        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
+                         max_model_len)
+
+        return cdiv(num_tokens, self.block_size) * self.page_size_bytes
+

 @dataclass
 class SlidingWindowSpec(AttentionSpec):
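
The new bound caps the per-layer KV footprint at one attention chunk plus one scheduling step instead of the full model length. A standalone sketch of the effect (not part of the diff; all concrete numbers below are made-up examples, and the helper names are ours):

def cdiv(a, b):
    return -(-a // b)


def max_pages(attention_chunk_size, max_num_batched_tokens, max_model_len,
              block_size):
    num_tokens = min(attention_chunk_size + max_num_batched_tokens,
                     max_model_len)
    return cdiv(num_tokens, block_size)


# E.g. chunk 8192, batch budget 2048, model length 131072, block size 16:
assert max_pages(8192, 2048, 131072, 16) == 640   # new bound
assert cdiv(131072, 16) == 8192                   # old bound (max_model_len)
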
@@ -862,6 +862,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                               (isinstance(kv_cache_spec, FullAttentionSpec)
                                and kv_cache_spec.sliding_window is not None))
+        use_local_attention = (
+            isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
+            or (isinstance(kv_cache_spec, FullAttentionSpec)
+                and kv_cache_spec.attention_chunk_size is not None))
         assert isinstance(kv_cache_spec, AttentionSpec)
         use_cascade = attn_metadata_builder.use_cascade_attention(
             common_prefix_len=common_prefix_len,
@@ -870,6 +874,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_kv_heads=kv_cache_spec.num_kv_heads,
             use_alibi=self.use_alibi,
             use_sliding_window=use_sliding_window,
+            use_local_attention=use_local_attention,
             num_sms=self.num_sms,
         )
         return common_prefix_len if use_cascade else 0
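
use_local_attention is true either for a genuine ChunkedLocalAttentionSpec or for a FullAttentionSpec that still carries attention_chunk_size, which is what unify_hybrid_kv_cache_specs produces when the hybrid manager is disabled and local-attention layers are widened. A standalone sketch of the predicate (not part of the diff; it uses stand-in spec classes, not the real ones):

from dataclasses import dataclass
from typing import Optional


@dataclass
class _FullAttentionSpec:  # stand-in for the real spec classes
    attention_chunk_size: Optional[int] = None


@dataclass
class _ChunkedLocalAttentionSpec:
    attention_chunk_size: int = 8192


def uses_local_attention(spec) -> bool:
    return (isinstance(spec, _ChunkedLocalAttentionSpec)
            or (isinstance(spec, _FullAttentionSpec)
                and spec.attention_chunk_size is not None))


assert uses_local_attention(_ChunkedLocalAttentionSpec())
assert uses_local_attention(_FullAttentionSpec(attention_chunk_size=8192))
assert not uses_local_attention(_FullAttentionSpec())
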
@@ -2672,6 +2677,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
+                    assert not use_local_attention, (
+                        "attention module cannot use both local attention "
+                        "and sliding window")
                 elif use_local_attention:
                     kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
                         block_size=block_size,