Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-19 21:14:47 +08:00)
[Attention] Cache attention metadata builds across hybrid KV-cache groups (#29627)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>

parent 254a7f8fd6
commit 9fec0e13d5
@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
     )

     # Call the function
-    result = make_local_attention_virtual_batches(
+    result, _ = make_local_attention_virtual_batches(
         attn_chunk_size, common_attn_metadata, block_size
     )

@@ -4,7 +4,7 @@ import functools

 import torch

-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig

@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
             common_prefix_len: int,
             common_attn_metadata: CommonAttentionMetadata,
             fast_build: bool = False,
-        ) -> AttentionMetadata:
-            common_attn_metadata = make_local_attention_virtual_batches(
+        ):
+            cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
                 attention_chunk_size, common_attn_metadata, block_size
             )
-            return super().build(common_prefix_len, common_attn_metadata, fast_build)
+            metadata = super().build(common_prefix_len, cm, fast_build)
+            metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
+            return metadata
+
+        def update_block_table(
+            self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
+        ):
+            blk_table = metadata.make_virtual_batches_block_table(blk_table)
+            return super().update_block_table(metadata, blk_table, slot_mapping)

     attn_backend = subclass_attention_backend(
         name_prefix=prefix,

@@ -207,7 +207,7 @@ if TYPE_CHECKING:
     VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
-    VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
+    VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: str | None = None
     VLLM_NVFP4_GEMM_BACKEND: str | None = None

@@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # kv-cache memory usage and enable longer contexts)
     # TODO(lucas): Remove this flag once latency regression is resolved.
     "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
-        int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))
+        int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
     ),
     # Enables support for the "store" option in the OpenAI Responses API.
     # When set to 1, vLLM's OpenAI server will retain the input and output

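For reference, a minimal standalone check of how the flipped default parses (the helper name below is made up; only the os.getenv expression mirrors the lambda above): leaving the variable unset now yields True, while an explicit "0" opts back out.

    import os


    def _allow_chunked_local_attn_with_hybrid_kv_cache() -> bool:
        # Mirrors the parsing expression above: unset or "1" -> True, "0" -> False.
        return bool(
            int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
        )


    os.environ.pop("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", None)
    assert _allow_chunked_local_attn_with_hybrid_kv_cache() is True   # new default: on

    os.environ["VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE"] = "0"
    assert _allow_chunked_local_attn_with_hybrid_kv_cache() is False  # explicit opt-out
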
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""

+import copy
 from dataclasses import dataclass
 from typing import ClassVar

@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(
         if get_flash_attn_version() == 3
         else AttentionCGSupport.UNIFORM_BATCH
     )
+    supports_update_block_table: bool = True

     def __init__(
         self,

@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(
         )
         return attn_metadata

+    def update_block_table(
+        self,
+        metadata: FlashAttentionMetadata,
+        blk_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> FlashAttentionMetadata:
+        new_metadata = copy.copy(metadata)
+        new_metadata.block_table = blk_table
+        new_metadata.slot_mapping = slot_mapping
+        return new_metadata
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)

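To see why a plain copy.copy is sufficient here: update_block_table only rebinds the block-table and slot-mapping attributes on the copy, so the cached metadata object is never mutated and every other field stays shared. A toy sketch, assuming a hypothetical _ToyMetadata stand-in rather than the real FlashAttentionMetadata:

    import copy
    from dataclasses import dataclass

    import torch


    @dataclass
    class _ToyMetadata:
        # Hypothetical stand-in for FlashAttentionMetadata: only the two fields
        # that update_block_table rebinds, plus one shared scalar, are modeled.
        block_table: torch.Tensor
        slot_mapping: torch.Tensor
        max_seq_len: int


    def update_block_table(metadata: _ToyMetadata,
                           blk_table: torch.Tensor,
                           slot_mapping: torch.Tensor) -> _ToyMetadata:
        # Shallow copy: fields we do not touch are shared with the cached object,
        # and the two tensor fields are rebound (not mutated) on the copy.
        new_metadata = copy.copy(metadata)
        new_metadata.block_table = blk_table
        new_metadata.slot_mapping = slot_mapping
        return new_metadata


    cached = _ToyMetadata(
        block_table=torch.zeros(2, 4, dtype=torch.int32),
        slot_mapping=torch.zeros(8, dtype=torch.int64),
        max_seq_len=64,
    )
    updated = update_block_table(
        cached, torch.ones(2, 4, dtype=torch.int32), torch.arange(8)
    )
    assert cached.block_table[0, 0] == 0              # cached metadata is untouched
    assert updated.max_seq_len == cached.max_seq_len  # non-tensor state is shared
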
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import itertools
 from dataclasses import dataclass

@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
     BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
 ):
+    supports_update_block_table: bool = True
+
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,

@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
             num_computed_tokens_p=num_computed_tokens_p,
         )
         return attn_metadata
+
+    def update_block_table(
+        self,
+        metadata: Mamba2AttentionMetadata,
+        blk_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> Mamba2AttentionMetadata:
+        new_metadata = copy.copy(metadata)
+        prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
+        num_reqs = blk_table.shape[0]
+
+        # For CUDA graphs, copy to persistent buffer
+        if (
+            metadata.num_prefills == 0
+            and num_reqs <= self.decode_cudagraph_max_bs
+            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+        ):
+            persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
+            persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
+            state_indices_t = persistent_state_indices_t
+
+        new_metadata.state_indices_tensor = state_indices_t
+        return new_metadata

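To make the state_indices_t selection concrete, a small standalone sketch with made-up shapes (not vLLM's real block-table layout): without prefix caching only the first block id per request is taken, matching blk_table[:, 0] above, and under full CUDA graphs the result is copied into a preallocated buffer so the captured graph always reads from the same storage.

    import torch

    # Toy block table: 3 requests, up to 4 blocks each (values are arbitrary).
    blk_table = torch.tensor([[7, 8, 0, 0],
                              [3, 0, 0, 0],
                              [5, 6, 9, 0]], dtype=torch.int32)

    # Without prefix caching only the first column is used; with prefix caching
    # the whole row is kept.
    state_indices_no_pc = blk_table[:, 0]   # shape (3,)   -> tensor([7, 3, 5], ...)
    state_indices_pc = blk_table            # shape (3, 4)

    # CUDA-graph style persistent buffer: copy into a preallocated tensor so a
    # captured graph always reads from the same storage (buffer size made up).
    persistent = torch.zeros(8, dtype=torch.int32)
    persistent[: state_indices_no_pc.shape[0]].copy_(
        state_indices_no_pc, non_blocking=True
    )
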
@@ -4,6 +4,7 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass, field, fields, make_dataclass
 from typing import (
     TYPE_CHECKING,

@@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
     reorder_batch_threshold: int | None = None
+    # Does this backend/builder support updating the block table in existing
+    # metadata
+    supports_update_block_table: bool = False

     @abstractmethod
     def __init__(

@@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError

+    def update_block_table(
+        self,
+        metadata: M,
+        blk_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> M:
+        """
+        Update the block table for the attention metadata.
+        Faster when there are multiple kv-cache groups that create virtually the
+        same metadata but just with different block tables.
+
+        Only needs to be implemented if supports_update_block_table is True.
+        """
+        raise NotImplementedError
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:

@@ -603,7 +622,7 @@ def make_local_attention_virtual_batches(
     attn_chunk_size: int,
     common_attn_metadata: CommonAttentionMetadata,
     block_size: int = 0,
-) -> CommonAttentionMetadata:
+) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
     query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
     seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
     block_table = common_attn_metadata.block_table_tensor

@@ -715,9 +734,12 @@ def make_local_attention_virtual_batches(
     # tensor first, which recovers perf.
     batch_indices_torch = torch.from_numpy(batch_indices)
     block_indices_torch = torch.from_numpy(block_indices)
-    block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
-        virtual_batches, -1
-    )
+
+    # Save as a lambda so we can return this for update_block_table
+    make_block_table = lambda block_table: block_table[
+        batch_indices_torch, block_indices_torch
+    ].view(virtual_batches, -1)
+    block_table_local = make_block_table(block_table)

     query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
     seq_lens_cpu = torch.from_numpy(seqlens_k_local)

@@ -736,7 +758,7 @@ def make_local_attention_virtual_batches(
             causal=True,
             _seq_lens_cpu=seq_lens_cpu,
             _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
-    )
+    ), make_block_table


 def make_kv_sharing_fast_prefill_common_attn_metadata(

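The make_block_table closure above just replays the same advanced-indexing gather on whatever block table it receives, which is what lets update_block_table remap another KV-cache group's table without redoing the virtual-batch bookkeeping. A toy illustration with made-up indices:

    import torch

    # Gather (batch, block) pairs into a virtual-batch block table, then replay
    # the same gather on another group's table of identical layout.
    batch_indices_torch = torch.tensor([0, 0, 1, 1])
    block_indices_torch = torch.tensor([0, 1, 0, 1])
    virtual_batches = 2

    make_block_table = lambda block_table: block_table[
        batch_indices_torch, block_indices_torch
    ].view(virtual_batches, -1)

    group_a = torch.arange(8).view(2, 4)        # block table of one KV-cache group
    group_b = torch.arange(8).view(2, 4) + 100  # same layout, different block ids

    print(make_block_table(group_a))  # tensor([[0, 1], [4, 5]])
    print(make_block_table(group_b))  # tensor([[100, 101], [104, 105]])
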
@@ -1630,6 +1630,15 @@ class GPUModelRunner(
                 logits_indices
             )

+            # Cache attention metadata builds across hybrid KV-cache groups.
+            # When different hybrid KV-cache groups share the same metadata builder
+            # and KVCacheSpec, the only thing that changes between them is the block
+            # table, so we can cache the attention metadata builds and just update
+            # the block table using `builder.update_block_table` if the builder
+            # supports it.
+            cached_attn_metadata: dict[
+                tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
+            ] = {}
+
             def _build_attn_group_metadata(
                 kv_cache_gid: int,
                 attn_gid: int,

@@ -1637,13 +1646,15 @@ class GPUModelRunner(
                 ubid: int | None = None,
             ) -> None:
                 attn_group = self.attn_groups[kv_cache_gid][attn_gid]
+                builder = attn_group.get_metadata_builder(ubid or 0)
+                cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
+
                 cascade_attn_prefix_len = (
                     cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
                     if cascade_attn_prefix_lens
                     else 0
                 )

-                builder = attn_group.get_metadata_builder(ubid or 0)
                 extra_attn_metadata_args = {}
                 if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
                     assert ubid is None, "UBatching not supported with GDN yet"

@@ -1658,12 +1669,23 @@ class GPUModelRunner(
                     attn_metadata_i = builder.build_for_cudagraph_capture(
                         common_attn_metadata
                     )
+                elif (
+                    cache_key in cached_attn_metadata
+                    and builder.supports_update_block_table
+                ):
+                    attn_metadata_i = builder.update_block_table(
+                        cached_attn_metadata[cache_key],
+                        common_attn_metadata.block_table_tensor,
+                        common_attn_metadata.slot_mapping,
+                    )
                 else:
                     attn_metadata_i = builder.build(
                         common_prefix_len=cascade_attn_prefix_len,
                         common_attn_metadata=common_attn_metadata,
                         **extra_attn_metadata_args,
                     )
+                    if builder.supports_update_block_table:
+                        cached_attn_metadata[cache_key] = attn_metadata_i

                 if ubid is None:
                     assert isinstance(attn_metadata, dict)

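Putting the runner-side pieces together, here is a stripped-down sketch of the caching pattern (simplified names; the real code shown above lives inside GPUModelRunner and also covers cudagraph capture, uBatching, and spec-decode arguments):

    # Cache-or-update decision only; builder and spec are duck-typed here.
    cached_attn_metadata = {}


    def metadata_for_group(builder, kv_cache_spec, common_attn_metadata, prefix_len=0):
        cache_key = (kv_cache_spec, type(builder))
        if cache_key in cached_attn_metadata and builder.supports_update_block_table:
            # Same spec + builder type: only the block table / slot mapping differ,
            # so patch the cached build instead of rebuilding from scratch.
            return builder.update_block_table(
                cached_attn_metadata[cache_key],
                common_attn_metadata.block_table_tensor,
                common_attn_metadata.slot_mapping,
            )
        metadata = builder.build(
            common_prefix_len=prefix_len, common_attn_metadata=common_attn_metadata
        )
        if builder.supports_update_block_table:
            cached_attn_metadata[cache_key] = metadata
        return metadata
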