[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson authored on 2025-08-06 21:40:52 -04:00; committed by GitHub
parent f825c6bd22
commit 1dc8a70b6d
13 changed files with 369 additions and 213 deletions

View File

@@ -313,7 +313,8 @@ def test_propose(num_speculative_tokens, backend):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,

View File

@@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
return rnd_stride
# Patch the attention backend class and re-trigger the KV cache creation.
for attn_backend in model_runner.attn_backends:
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)
model_runner.attn_backends = []
model_runner.attn_metadata_builders = []
model_runner.attn_groups = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
# Shape is unchanged, but layout may differ

View File

@@ -106,6 +106,10 @@ class AttentionBackend(ABC):
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@dataclass
class AttentionMetadata:
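
full_cls_name() gives a (module, qualified name) pair that identifies a backend class even when dynamic subclassing produces distinct class objects per layer; the reworked GPU model runner later keys its backend dedup on it (see get_attn_backends_for_layers below). A minimal, self-contained sketch of that grouping, with stand-in classes and a hypothetical layer map rather than real vLLM imports:

from collections import defaultdict


class AttentionBackend:

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        # (module, qualified name) identifies the backend class even when
        # dynamic subclassing produces distinct class objects per layer.
        return (cls.__module__, cls.__qualname__)


class FlashBackend(AttentionBackend):
    pass


# Hypothetical layer -> backend map; in vLLM this comes from each Attention
# layer's get_attn_backend().
layer_backends = {"layers.0.attn": FlashBackend, "layers.1.attn": FlashBackend}

backends: dict[tuple[str, str], type[AttentionBackend]] = {}
backend_layers: dict[tuple[str, str], list[str]] = defaultdict(list)
for layer_name, backend in layer_backends.items():
    key = backend.full_cls_name()
    backends[key] = backend
    backend_layers[key].append(layer_name)

# Both layers collapse into one group keyed by the backend's full class name.
print({backends[k].__qualname__: names for k, names in backend_layers.items()})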

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
@@ -80,6 +81,7 @@ class Attention(nn.Module):
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
@@ -137,15 +139,6 @@
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
# For v1 we have backend agnostic iRoPE (local chunked attention)
# we have to store the flag on the layer so gpu model runner can
# set KVSpec appropriately (and pop it so it doesn't get passed to
# the backends)
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)
else:
self.use_irope = extra_impl_args.get("use_irope", False)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_output = attn_backend.accept_output_buffer
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
from ..layer import Attention
@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
return attn_backend
class ChunkedLocalAttention(Attention):
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
attention_chunk_size: int,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
kv_sharing_target_layer_name: Optional[str] = None,
prefix: str = ""):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size)
else:
# in v0 the local attention is handled inside the backends
attn_backend = None
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend)
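
The functools.lru_cache on the factory is load-bearing: it returns the same dynamically created wrapper class for every layer that shares an (underlying backend, attention_chunk_size, block_size) triple, which is what lets the model runner's full_cls_name() dedup collapse those layers into a single attention group. A self-contained sketch of that caching behaviour with stand-in names:

import functools


@functools.lru_cache
def make_wrapped_backend(base: type, chunk_size: int, block_size: int) -> type:
    # Dynamically create a subclass, as create_chunked_local_attention_backend
    # does; lru_cache ensures identical arguments return the *same* class
    # object instead of a fresh subclass per call (i.e. per layer).
    name = f"Chunked_{chunk_size}_{block_size}_{base.__name__}"
    return type(name, (base,), {})


class BaseBackend:
    pass


a = make_wrapped_backend(BaseBackend, 8192, 16)
b = make_wrapped_backend(BaseBackend, 8192, 16)
c = make_wrapped_backend(BaseBackend, 4096, 16)

assert a is b       # cached: layers with the same config share one class
assert a is not c   # a different chunk size yields a different subclass
print(a.__name__, c.__name__)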

View File

@@ -142,7 +142,7 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_attention_free: bool = False,
use_mla: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""

View File

@@ -25,6 +25,7 @@ from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -194,17 +195,18 @@ class Llama4Attention(nn.Module):
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
attn_cls = Attention if self.nope else ChunkedLocalAttention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
**({
"attention_chunk_size": config.attention_chunk_size
} if not self.nope else {}))
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)

View File

@@ -5,12 +5,12 @@ import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
TypeVar)
import numpy as np
import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
)
def subclass_attention_metadata_builder(
name_prefix: str,
builder_cls: type[AttentionMetadataBuilder[M]],
build_preprocess_fn: Callable[[CommonAttentionMetadata],
CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
"""
Return a new subclass of `builder_cls` whose .build(...) method
first calls build_preprocess_fn(common_attn_metadata) on the metadata.
"""
name: str = name_prefix + builder_cls.__name__ # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
return builder_cls.build(self, common_prefix_len,
build_preprocess_fn(common_attn_metadata),
fast_build)
Wrapped = type(
name,
(builder_cls, ), # inherit from the original
{
"build": build,
})
return Wrapped # type: ignore
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls, ),
{"get_builder_cls": lambda: builder_cls})
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
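
Both helpers build the new class with type() and override a single method. A self-contained toy version of the builder wrapper follows; Metadata, BaseBuilder and halve_tokens stand in for CommonAttentionMetadata, a real metadata builder, and make_local_attention_virtual_batches:

from typing import Callable


class Metadata:

    def __init__(self, num_tokens: int):
        self.num_tokens = num_tokens


class BaseBuilder:

    def build(self, common_prefix_len: int, metadata: Metadata,
              fast_build: bool = False) -> str:
        return f"built(prefix={common_prefix_len}, tokens={metadata.num_tokens})"


def subclass_builder(name_prefix: str, builder_cls: type,
                     preprocess: Callable[[Metadata], Metadata]) -> type:
    # Same shape as subclass_attention_metadata_builder: the generated
    # subclass preprocesses the metadata before delegating to the parent.
    def build(self, common_prefix_len: int, metadata: Metadata,
              fast_build: bool = False):
        return builder_cls.build(self, common_prefix_len, preprocess(metadata),
                                 fast_build)

    return type(name_prefix + builder_cls.__name__, (builder_cls, ),
                {"build": build})


def halve_tokens(m: Metadata) -> Metadata:
    # Stand-in for make_local_attention_virtual_batches.
    return Metadata(m.num_tokens // 2)


LocalBuilder = subclass_builder("Local_", BaseBuilder, halve_tokens)
print(LocalBuilder().build(0, Metadata(32)))  # built(prefix=0, tokens=16)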

View File

@@ -158,9 +158,9 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@@ -349,7 +349,8 @@ class EagleProposer:
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builder
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)

View File

@@ -53,11 +53,11 @@ class CPUModelRunner(GPUModelRunner):
raise ValueError("Multiple KVCacheGroups are not "
"currently supported with CPU model runner.")
assert type(
self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
assert type(self.attn_groups[0]
[0].metadata_builder) is TorchSDPAMetadataBuilderV1
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
self.attn_groups[0][0].metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
def _postprocess_tenosrs(self) -> None:
# Note: replace device tensors with cpu tensors

View File

@@ -3,7 +3,10 @@
import dataclasses
import gc
import itertools
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast
@@ -14,9 +17,9 @@ import torch.nn as nn
from tqdm import tqdm
import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config, update_config)
@@ -50,7 +53,6 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
@@ -73,8 +75,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
initialize_kv_cache_for_kv_sharing,
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING:
@@ -162,8 +164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
self.attn_backends: list[type[AttentionBackend]] = []
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output)
@@ -830,81 +832,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
if isinstance(kv_cache_group_spec.kv_cache_spec,
ChunkedLocalAttentionSpec):
common_attn_metadata = make_local_attention_virtual_batches(
kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
common_attn_metadata, self.cache_config.block_size)
for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = attn_group.metadata_builder
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
builder,
)
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in kv_cache_group_spec.layer_names:
if (self.cache_config.kv_sharing_fast_prefill and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i
# Hack for now to fix chunked local attention + no hybrid kv cache
# manager we can remove this once
# https://github.com/vllm-project/vllm/pull/21588
# is merged (i.e. properly handle different attention backends for
# the same kv_cache_spec)
if self.attention_chunk_size is not None \
and self.scheduler_config.disable_hybrid_kv_cache_manager:
if not hasattr(self, "local_attention_layers"):
self.local_attention_layers = []
attn_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.use_irope:
self.local_attention_layers.append(layer_name)
local_attn_metadata_i = (builder.build(
common_prefix_len=0,
common_attn_metadata=make_local_attention_virtual_batches(
self.attention_chunk_size, common_attn_metadata,
self.cache_config.block_size),
attn_metadata_i = (builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
))
for layer_name in self.local_attention_layers:
attn_metadata[layer_name] = local_attn_metadata_i
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in attn_group.layer_names:
if (self.cache_config.kv_sharing_fast_prefill
and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i
attention_cuda_graphs = all(
b.can_run_in_cudagraph(common_attn_metadata)
for b in self.attn_metadata_builders)
g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
for g in self._attn_group_iterator())
# Hot-Swap lora model
if self.lora_config:
@@ -2229,11 +2201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\
.build_for_cudagraph_capture(common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@@ -2565,88 +2537,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def _initialize_single_attn_backend(
self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_backend_i, attn_metadata_builder_i
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
"""
assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
attn_backend_i, attn_metadata_builder_i = (
self._initialize_single_attn_backend(
kv_cache_spec, kv_cache_group_spec.layer_names))
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than
# using the class itself as the key because when we create dynamic
# attention backend subclasses (e.g. ChunkedLocalAttention) unless
# they are cached correctly, there will be different objects per
# layer.
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
return {
attn_backends[k]: v
for k, v in attn_backend_layers.items()
}
def create_attn_groups(
attn_backends_map: dict[AttentionBackend, list[str]],
kv_cache_spec: KVCacheSpec,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for attn_backend, layer_names in attn_backends_map.items():
attn_metadata_builder_i = attn_backend.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
attn_group = AttentionGroup(attn_backend,
attn_metadata_builder_i,
layer_names)
attn_groups.append(attn_group)
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif isinstance(kv_cache_spec, MambaSpec):
attn_backends = {
get_mamba_attn_backend(kv_cache_spec.mamba_type):
kv_cache_group_spec.layer_names
}
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
if len(self.attn_backends) > 0:
if len(self.attn_groups) > 0:
return
# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
attn_specs = list[AttentionSpec]()
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for attn_module in attn_layers.values():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
@@ -2666,11 +2650,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(attn_specs) == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
attn_backend, attn_metadata_builder = (
self._initialize_single_attn_backend(attn_specs[0],
attn_layers.keys()))
self.attn_backends.append(attn_backend)
self.attn_metadata_builders.append(attn_metadata_builder)
attn_backends = get_attn_backends_for_layers(attn_layers.keys())
self.attn_groups.append(
create_attn_groups(attn_backends, attn_specs[0]))
self.is_encoder_only_model = True
def calculate_reorder_batch_threshold(self) -> None:
@@ -2678,7 +2661,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Check that if any backends reorder batches; that the reordering
is compatible (e.g., decode threshold is the same)
"""
for attn_metadata_builder_i in self.attn_metadata_builders:
for group in self._attn_group_iterator():
attn_metadata_builder_i = group.metadata_builder
# check that if any backends reorder batches; that the reordering
# is compatible (e.g., decode threshold is the same)
reorder_batch_threshold_i = (
@@ -2752,6 +2737,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
@@ -2770,23 +2767,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = self.attn_backends[
i].get_kv_cache_stride_order()
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
@@ -2850,15 +2846,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_raw_tensors: The KV cache buffer of each layer.
"""
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
for layer_name in group.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
kv_cache_shape = group.backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
if kv_cache_shape[0] != num_blocks or kv_cache_shape[
@@ -2893,6 +2888,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
self.attn_groups,
)
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
@@ -2958,9 +2954,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
use_local_attention = (self.attention_chunk_size is not None
and attn_module.use_irope)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
@@ -2969,10 +2965,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
assert not use_local_attention, (
"attention module can not be with ",
"both local attention and sliding window")
elif use_local_attention:
elif self.attention_chunk_size is not None \
and isinstance(attn_module, ChunkedLocalAttention):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
@@ -3043,7 +3037,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Use the first attention metadata builder
# to create encoder attention metadata
builder = self.attn_metadata_builders[0]
builder = self.attn_groups[0][0].metadata_builder
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,

View File

@@ -15,8 +15,9 @@ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
@@ -518,7 +519,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.use_irope:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context.")

View File

@@ -1,14 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
@@ -122,6 +125,13 @@ class MultiModalBudget:
return max_items_per_prompt, max_items_per_batch
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
metadata_builder: AttentionMetadataBuilder
layer_names: list[str]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
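
AttentionGroup is the unit the GPU model runner now stores in self.attn_groups, a nested list indexed as [kv_cache_group_id][attn_group]: one outer entry per KV cache group, one inner entry per distinct backend within it. A self-contained sketch of that structure and of the flattened iteration used by _attn_group_iterator(); the backend, builder and layer names are string stand-ins:

import itertools
from dataclasses import dataclass


@dataclass
class AttentionGroup:
    # Mirrors the dataclass above; backend and metadata_builder are plain
    # strings here so the sketch stays self-contained.
    backend: object
    metadata_builder: object
    layer_names: list


attn_groups = [
    [AttentionGroup("FlashAttnBackend", "flash_builder", ["layers.0.attn"])],
    [
        AttentionGroup("FlashAttnBackend", "flash_builder", ["layers.1.attn"]),
        AttentionGroup("ChunkedLocalBackend", "local_builder",
                       ["layers.2.attn"]),
    ],
]

# Per-group access, e.g. the drafting path reads attn_groups[0][0].
first_builder = attn_groups[0][0].metadata_builder

# Flattened iteration, matching GPUModelRunner._attn_group_iterator().
for group in itertools.chain.from_iterable(attn_groups):
    print(group.backend, group.layer_names)
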
@@ -196,6 +206,8 @@ def initialize_kv_cache_for_kv_sharing(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
kv_caches: dict[str, torch.Tensor],
# Optional for now to avoid breaking TPU
attn_groups: Optional[list[list[AttentionGroup]]] = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -225,6 +237,15 @@
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name)
if attn_groups is not None:
assert len(attn_groups[group_idx]) == 1, (
"Only one attention group per KV cache group is supported "
"for KV-cache sharing for now.")
# TODO(lucas): I think in the future the layers that re-use a
# KV cache will be in a different attention group so we can
# remove this code from here.
attn_groups[group_idx][0].layer_names.append(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],