[Attention] Cache attention metadata builds across hybrid KV-cache groups (#29627)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Lucas Wilkinson 2025-12-16 17:10:16 -05:00 committed by GitHub
parent 254a7f8fd6
commit 9fec0e13d5
7 changed files with 105 additions and 13 deletions

View File

@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
)
# Call the function
result = make_local_attention_virtual_batches(
result, _ = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size
)

View File

@@ -4,7 +4,7 @@ import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
)
return super().build(common_prefix_len, common_attn_metadata, fast_build)
metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
return metadata
def update_block_table(
self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
blk_table = metadata.make_virtual_batches_block_table(blk_table)
return super().update_block_table(metadata, blk_table, slot_mapping)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
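
A note on the wrapper above: build() now stashes the block-table remapping closure returned by make_local_attention_virtual_batches on the metadata, and update_block_table() re-applies that closure to a new block table instead of rebuilding everything. A minimal sketch of the same pattern, using toy classes rather than the real vLLM builders:

# Minimal sketch of the "stash a remap closure on the metadata" pattern above.
# Class and field names here are illustrative, not the actual vLLM API.
import copy
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class ToyMetadata:
    block_table: torch.Tensor
    remap_block_table: Optional[Callable[[torch.Tensor], torch.Tensor]] = None


class ToyBuilder:
    def build(self, block_table: torch.Tensor) -> ToyMetadata:
        num_reqs, blocks_per_req = block_table.shape
        # Pretend "virtual batching" splits every request into two halves.
        remap = lambda bt: bt.reshape(2 * num_reqs, blocks_per_req // 2)
        return ToyMetadata(block_table=remap(block_table), remap_block_table=remap)

    def update_block_table(self, md: ToyMetadata, new_bt: torch.Tensor) -> ToyMetadata:
        new_md = copy.copy(md)  # shallow copy, like the real builders do
        new_md.block_table = md.remap_block_table(new_bt)
        return new_md


builder = ToyBuilder()
md = builder.build(torch.arange(16).reshape(2, 8))
md2 = builder.update_block_table(md, torch.arange(16, 32).reshape(2, 8))
assert md2.block_table.shape == (4, 4)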

View File

@@ -207,7 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None
@@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# kv-cache memory usage and enable longer contexts)
# TODO(lucas): Remove this flag once latency regression is resolved.
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
),
# Enables support for the "store" option in the OpenAI Responses API.
# When set to 1, vLLM's OpenAI server will retain the input and output
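
For reference, VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE above follows the bool(int(os.getenv(...))) pattern shown in the hunk, so "0" disables the behaviour and any non-zero integer enables it; this change flips the default to enabled. A tiny sketch of that parsing convention (the helper name is made up for illustration):

# Sketch of the env-var parsing convention above: the value is parsed as an
# int and coerced to bool, with the default now "1" (enabled).
import os

def _chunked_local_attn_with_hybrid_kv_cache() -> bool:  # illustrative helper
    return bool(int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1")))

os.environ["VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE"] = "0"
assert _chunked_local_attn_with_hybrid_kv_cache() is False
del os.environ["VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE"]
assert _chunked_local_attn_with_hybrid_kv_cache() is True  # new default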

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import copy
from dataclasses import dataclass
from typing import ClassVar
@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH
)
supports_update_block_table: bool = True
def __init__(
self,
@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
)
return attn_metadata
def update_block_table(
self,
metadata: FlashAttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> FlashAttentionMetadata:
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
return new_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)
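
The copy.copy in update_block_table above is deliberately shallow: the new metadata shares every field with the cached object except the block table and slot mapping, which are swapped for the current group's tensors. A small sketch of that semantics with a stand-in dataclass (not the real FlashAttentionMetadata):

# Sketch of the shallow-copy reuse above: only the block table and slot
# mapping are replaced; every other field still aliases the cached metadata.
# The dataclass below is a stand-in, not the real FlashAttentionMetadata.
import copy
from dataclasses import dataclass

import torch


@dataclass
class StandInMetadata:
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor


cached = StandInMetadata(
    query_start_loc=torch.tensor([0, 4, 9]),
    seq_lens=torch.tensor([4, 5]),
    block_table=torch.zeros(2, 8, dtype=torch.int32),
    slot_mapping=torch.zeros(9, dtype=torch.int64),
)

new_md = copy.copy(cached)
new_md.block_table = torch.ones(2, 8, dtype=torch.int32)
new_md.slot_mapping = torch.arange(9)

assert new_md.query_start_loc is cached.query_start_loc  # shared, not copied
assert new_md.block_table is not cached.block_table      # swapped per group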

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
from dataclasses import dataclass
@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
supports_update_block_table: bool = True
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p=num_computed_tokens_p,
)
return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata
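
The CUDA-graph branch above exists because a captured full graph replays against fixed tensor addresses: a pure-decode batch must have its state indices written into the persistent state_indices_tensor the graph was captured with, rather than pointed at a freshly allocated tensor. A small sketch of that copy-into-persistent-slice pattern (sizes and names are illustrative, and it runs on CPU just to show the mechanics):

# Sketch of the "copy into a persistent buffer slice" pattern above for full
# CUDA graphs: the captured graph reads state_indices_tensor, so a replayed
# decode batch must have its indices copied into that same buffer.
# Sizes/names are illustrative; CPU tensors are used only to show the mechanics.
import torch

decode_cudagraph_max_bs = 8
# Persistent buffer allocated once (on the GPU in the real code).
state_indices_tensor = torch.zeros(decode_cudagraph_max_bs, dtype=torch.int32)

def bind_state_indices(blk_table: torch.Tensor) -> torch.Tensor:
    # blk_table[:, 0] plays the role of the per-request state index here.
    state_indices = blk_table[:, 0].to(torch.int32)
    num_reqs = state_indices.shape[0]
    if num_reqs <= decode_cudagraph_max_bs:
        persistent = state_indices_tensor[:num_reqs]
        persistent.copy_(state_indices, non_blocking=True)
        return persistent  # same storage the captured graph reads
    return state_indices   # eager fallback for oversized batches

bound = bind_state_indices(torch.arange(12).reshape(4, 3))
assert bound.data_ptr() == state_indices_tensor.data_ptr()  # aliases the buffer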

View File

@@ -4,6 +4,7 @@ import abc
import enum
import functools
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import (
TYPE_CHECKING,
@@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None
# Does this backend/builder support updating the block table of existing
# metadata?
supports_update_block_table: bool = False
@abstractmethod
def __init__(
@@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
"""
Update the block table for the attention metadata.
Faster than a full build when there are multiple kv-cache groups that
produce essentially the same metadata, differing only in their block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
@@ -603,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0,
) -> CommonAttentionMetadata:
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor
@@ -715,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices)
block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
virtual_batches, -1
)
# Save as a lambda so it can be returned and re-applied by update_block_table
make_block_table = lambda block_table: block_table[
batch_indices_torch, block_indices_torch
].view(virtual_batches, -1)
block_table_local = make_block_table(block_table)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
@@ -736,7 +758,7 @@ def make_local_attention_virtual_batches(
causal=True,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
)
), make_block_table
def make_kv_sharing_fast_prefill_common_attn_metadata(
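
The make_block_table closure above is a single advanced-indexing gather followed by a reshape: precomputed (batch, block) index pairs pick out, for each virtual batch, the blocks of its local attention window, and the result is viewed as one row per virtual batch. A self-contained sketch of that gather (the index values below are invented purely for illustration):

# Sketch of what the make_block_table closure above does: one advanced-
# indexing gather with precomputed (batch, block) indices, reshaped to one
# row per virtual batch. The concrete indices below are made up.
import torch

block_table = torch.arange(100, 116).reshape(2, 8)  # 2 requests x 8 blocks

# Pretend each request was split into 2 virtual batches of 4 blocks each.
batch_indices = torch.tensor([0, 0, 0, 0,  0, 0, 0, 0,
                              1, 1, 1, 1,  1, 1, 1, 1])
block_indices = torch.tensor([0, 1, 2, 3,  2, 3, 4, 5,
                              0, 1, 2, 3,  2, 3, 4, 5])
virtual_batches = 4

make_block_table = lambda bt: bt[batch_indices, block_indices].view(virtual_batches, -1)

local = make_block_table(block_table)
assert local.shape == (4, 4)
# The same closure can be re-applied to a different group's block table with
# the same layout, which is exactly what update_block_table relies on.
other = make_block_table(block_table + 1000)
assert other.shape == (4, 4)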

View File

@@ -1630,6 +1630,15 @@ class GPUModelRunner(
logits_indices
)
# Cache attention metadata builds across hybrid KV-cache groups.
# When hybrid KV-cache groups share the same metadata builder type and
# KVCacheSpec, the only thing that differs between them is the block table,
# so we can cache the attention metadata builds and just update the block
# table using `builder.update_block_table` if the builder supports it.
cached_attn_metadata: dict[
tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
] = {}
def _build_attn_group_metadata(
kv_cache_gid: int,
attn_gid: int,
@@ -1637,13 +1646,15 @@ class GPUModelRunner(
ubid: int | None = None,
) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
builder = attn_group.get_metadata_builder(ubid or 0)
cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet"
@@ -1658,12 +1669,23 @@ class GPUModelRunner(
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
elif (
cache_key in cached_attn_metadata
and builder.supports_update_block_table
):
attn_metadata_i = builder.update_block_table(
cached_attn_metadata[cache_key],
common_attn_metadata.block_table_tensor,
common_attn_metadata.slot_mapping,
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
if builder.supports_update_block_table:
cached_attn_metadata[cache_key] = attn_metadata_i
if ubid is None:
assert isinstance(attn_metadata, dict)
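
Putting the runner-side change together: the cache is keyed by (KVCacheSpec, builder type), the first group with a given key pays for a full build, and every later group with the same key only swaps in its own block table and slot mapping. A condensed sketch of that control flow, with toy classes standing in for the real builders and metadata:

# Condensed sketch of the caching control flow above: key by
# (kv_cache_spec, type(builder)); build once, then update_block_table for the
# remaining hybrid KV-cache groups. Toy classes stand in for the real ones.
import copy
from dataclasses import dataclass

import torch


@dataclass
class ToyMeta:
    block_table: torch.Tensor


class ToyBuilder:
    supports_update_block_table = True
    full_builds = 0  # count full builds across all instances

    def build(self, block_table: torch.Tensor) -> ToyMeta:
        ToyBuilder.full_builds += 1
        return ToyMeta(block_table=block_table)

    def update_block_table(self, md: ToyMeta, block_table: torch.Tensor) -> ToyMeta:
        new_md = copy.copy(md)
        new_md.block_table = block_table
        return new_md


groups = [
    ("full_attn_spec", ToyBuilder(), torch.zeros(2, 8)),
    ("full_attn_spec", ToyBuilder(), torch.ones(2, 8)),     # same key -> cache hit
    ("mamba_spec", ToyBuilder(), torch.full((2, 4), 2.0)),  # different spec -> build
]

cached: dict[tuple[str, type], ToyMeta] = {}
for spec, builder, block_table in groups:
    key = (spec, type(builder))
    if key in cached and builder.supports_update_block_table:
        md = builder.update_block_table(cached[key], block_table)
    else:
        md = builder.build(block_table)
        if builder.supports_update_block_table:
            cached[key] = md

assert ToyBuilder.full_builds == 2  # the second full-attention group reused the cache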