[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-08-06 21:40:52 -04:00 committed by GitHub
parent f825c6bd22
commit 1dc8a70b6d
13 changed files with 369 additions and 213 deletions
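
For orientation, a minimal, self-contained sketch of the grouping this commit introduces (toy classes, not vLLM's actual API): layers that share one kv_cache_spec are partitioned into per-backend AttentionGroup entries, so a single KV cache group can now own several attention metadata builders.

# Hypothetical sketch, not the real vLLM code paths.
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class AttentionGroup:          # mirrors the dataclass added in v1/worker/utils.py
    backend: type
    metadata_builder: object
    layer_names: list


class ToyBuilder:              # stand-in for a real metadata builder
    pass


class FullAttnBackend:         # hypothetical stand-in backends
    @staticmethod
    def get_builder_cls():
        return ToyBuilder


class LocalAttnBackend(FullAttnBackend):
    pass


def group_layers_by_backend(layer_to_backend):
    """One AttentionGroup per distinct backend within a KV cache group."""
    per_backend = defaultdict(list)
    for layer_name, backend in layer_to_backend.items():
        per_backend[backend].append(layer_name)
    return [
        AttentionGroup(b, b.get_builder_cls()(), names)
        for b, names in per_backend.items()
    ]


groups = group_layers_by_backend({
    "layers.0.attn": FullAttnBackend,   # global attention layer
    "layers.1.attn": LocalAttnBackend,  # chunked local attention layer
    "layers.2.attn": FullAttnBackend,
})
assert len(groups) == 2  # one kv_cache_spec, two metadata builders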

View File

@@ -313,7 +313,8 @@ def test_propose(num_speculative_tokens, backend):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_metadata_builders = [attn_metadata_builder]
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,

View File

@@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
        return rnd_stride
    # Patch the attention backend class and re-trigger the KV cache creation.
-    for attn_backend in model_runner.attn_backends:
+    for attn_group in model_runner._attn_group_iterator():
+        attn_backend = attn_group.backend
        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
                            rnd_stride_order)
-    model_runner.attn_backends = []
-    model_runner.attn_metadata_builders = []
+    model_runner.attn_groups = []
    model_runner.initialize_kv_cache(model_runner.kv_cache_config)
    # Shape is unchanged, but layout may differ

View File

@@ -106,6 +106,10 @@ class AttentionBackend(ABC):
                           block_size: int, num_seqs: int, num_queries: int) -> None:
        raise NotImplementedError

+    @classmethod
+    def full_cls_name(cls) -> tuple[str, str]:
+        return (cls.__module__, cls.__qualname__)
+

 @dataclass
 class AttentionMetadata:

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import vllm.envs as envs
 from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -80,6 +81,7 @@ class Attention(nn.Module):
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        attn_backend: Optional[type[AttentionBackend]] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -137,15 +139,6 @@ class Attention(nn.Module):
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window

-        # For v1 we have backend agnostic iRoPE (local chunked attention)
-        # we have to store the flag on the layer so gpu model runner can
-        # set KVSpec appropriately (and pop it so it doesnt get passed to
-        # the backends)
-        if envs.VLLM_USE_V1:
-            self.use_irope = extra_impl_args.pop("use_irope", False)
-        else:
-            self.use_irope = extra_impl_args.get("use_irope", False)
-
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@ class Attention(nn.Module):
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype,
-                                        block_size,
-                                        is_attention_free,
-                                        use_mla=use_mla)
-        impl_cls = attn_backend.get_impl_cls()
+        if attn_backend is None:
+            self.attn_backend = get_attn_backend(head_size,
+                                                 dtype,
+                                                 kv_cache_dtype,
+                                                 block_size,
+                                                 is_attention_free,
+                                                 use_mla=use_mla)
+        else:
+            self.attn_backend = attn_backend
+
+        impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              logits_soft_cap, attn_type,
                              kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.backend = backend_name_to_enum(self.attn_backend.get_name())
         self.dtype = dtype

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@ class Attention(nn.Module):
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()

-        self.use_output = attn_backend.accept_output_buffer
+        self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)

+    def get_attn_backend(self) -> type[AttentionBackend]:
+        return self.attn_backend
+

 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""

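A hedged sketch of the constructor hook added above: an explicitly passed attn_backend takes precedence, otherwise Attention falls back to the usual selector and then exposes the result via get_attn_backend(). The helper and backend names below are illustrative, not vLLM's.

# Hypothetical illustration of the dispatch rule, not the real constructor.
from typing import Callable, Optional


def resolve_attn_backend(passed_backend: Optional[type],
                         select_default: Callable[[], type]) -> type:
    """Mirror of the `if attn_backend is None` branch in Attention.__init__."""
    return select_default() if passed_backend is None else passed_backend


class FlashLikeBackend:          # hypothetical stand-ins
    pass


class WrappedLocalBackend(FlashLikeBackend):
    pass


# Default path: nothing injected, the selector decides.
assert resolve_attn_backend(None, lambda: FlashLikeBackend) is FlashLikeBackend
# Injection path: e.g. ChunkedLocalAttention passes its wrapped backend through.
assert resolve_attn_backend(WrappedLocalBackend,
                            lambda: FlashLikeBackend) is WrappedLocalBackend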
View File

@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from typing import List, Optional
+
+import torch
+
+from vllm import envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.selector import get_attn_backend
+from vllm.config import CacheConfig, QuantizationConfig
+from vllm.v1.attention.backends.utils import (
+    CommonAttentionMetadata, make_local_attention_virtual_batches,
+    subclass_attention_backend, subclass_attention_metadata_builder)
+
+from ..layer import Attention
+
+
+@functools.lru_cache
+def create_chunked_local_attention_backend(
+    underlying_attn_backend: AttentionBackend,
+    attention_chunk_size: int,
+    block_size: int,
+) -> type[AttentionBackend]:
+    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
+
+    def build_preprocess_fn(cm: CommonAttentionMetadata):
+        return make_local_attention_virtual_batches(attention_chunk_size, cm,
+                                                    block_size)
+
+    # Dynamically create a new attention backend that wraps the
+    # underlying attention backend but applies
+    # `make_local_attention_virtual_batches` before calling `build(...)`
+    builder_cls = subclass_attention_metadata_builder(
+        name_prefix=prefix,
+        builder_cls=underlying_attn_backend.get_builder_cls(),
+        build_preprocess_fn=build_preprocess_fn)
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=builder_cls)
+
+    return attn_backend
+
+
+class ChunkedLocalAttention(Attention):
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 attention_chunk_size: int,
+                 num_kv_heads: Optional[int] = None,
+                 alibi_slopes: Optional[List[float]] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 kv_sharing_target_layer_name: Optional[str] = None,
+                 prefix: str = ""):
+        dtype = torch.get_default_dtype()
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if envs.VLLM_USE_V1:
+            underlying_attn_backend = get_attn_backend(head_size, dtype,
+                                                       kv_cache_dtype,
+                                                       block_size)
+
+            attn_backend = create_chunked_local_attention_backend(
+                underlying_attn_backend, attention_chunk_size, block_size)
+        else:
+            # in v0 the local attention is handled inside the backends
+            attn_backend = None
+
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=alibi_slopes,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=prefix,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            attn_backend=attn_backend)

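A note on the functools.lru_cache above: identical (underlying backend, chunk size, block size) arguments must return the same dynamically created subclass, otherwise every layer would appear to use a distinct backend and the model runner would build redundant attention groups. A self-contained sketch of that caching behavior with a hypothetical backend class:

# Hypothetical sketch of why the lru_cache matters; not vLLM code.
import functools


class BaseBackend:
    pass


@functools.lru_cache
def make_wrapped_backend(underlying: type, chunk_size: int,
                         block_size: int) -> type:
    # Same dynamic-subclassing trick as create_chunked_local_attention_backend.
    name = f"ChunkedLocal_{chunk_size}_{block_size}_{underlying.__name__}"
    return type(name, (underlying, ), {})


a = make_wrapped_backend(BaseBackend, 8192, 16)
b = make_wrapped_backend(BaseBackend, 8192, 16)
c = make_wrapped_backend(BaseBackend, 4096, 16)
assert a is b        # cached: all chunk-8192 layers share one class
assert a is not c    # different chunk size -> different subclass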
View File

@@ -142,7 +142,7 @@ def get_attn_backend(
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
-    is_attention_free: bool,
+    is_attention_free: bool = False,
    use_mla: bool = False,
 ) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

View File

@@ -25,6 +25,7 @@ from torch import nn
 from transformers import Llama4TextConfig

 from vllm.attention import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -194,17 +195,18 @@ class Llama4Attention(nn.Module):
             is_neox_style=is_neox_style,
         ) if not self.nope else None

-        self.attn = Attention(
+        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
-            per_layer_sliding_window=None,
-            use_irope=not self.nope,
             prefix=f"{prefix}.attn",
-        )
+            **({
+                "attention_chunk_size": config.attention_chunk_size
+            } if not self.nope else {}))

     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)

View File

@@ -5,12 +5,12 @@ import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
+                    TypeVar)

 import numpy as np
 import torch

-from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch

 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
 from vllm.logger import init_logger
@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
    )


+def subclass_attention_metadata_builder(
+    name_prefix: str,
+    builder_cls: type[AttentionMetadataBuilder[M]],
+    build_preprocess_fn: Callable[[CommonAttentionMetadata],
+                                  CommonAttentionMetadata],
+) -> type[AttentionMetadataBuilder[M]]:
+    """
+    Return a new subclass of `builder_cls` whose .build(...) method
+    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
+    """
+    name: str = name_prefix + builder_cls.__name__  # type: ignore
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False):
+        return builder_cls.build(self, common_prefix_len,
+                                 build_preprocess_fn(common_attn_metadata),
+                                 fast_build)
+
+    Wrapped = type(
+        name,
+        (builder_cls, ),  # inherit from the original
+        {
+            "build": build,
+        })
+    return Wrapped  # type: ignore
+
+
+def subclass_attention_backend(
+        name_prefix: str, attention_backend_cls: type[AttentionBackend],
+        builder_cls: type[AttentionMetadataBuilder[M]]
+) -> type[AttentionBackend]:
+    """
+    Return a new subclass where `get_builder_cls` returns `builder_cls`.
+    """
+    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
+
+    return type(name, (attention_backend_cls, ),
+                {"get_builder_cls": lambda: builder_cls})
+
+
 def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,

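A hedged usage sketch of the two helpers added above, with toy builder and backend classes standing in for the real vLLM ones: the generated builder preprocesses the common metadata and then defers to the parent build(), while the generated backend simply reports the new builder class. Note that, as in the real helpers, get_builder_cls is invoked on the class itself, not on an instance.

# Toy illustration; ToyBuilder/ToyBackend are hypothetical stand-ins.
from typing import Callable


class ToyBuilder:
    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        return ("built", common_attn_metadata)


class ToyBackend:
    @staticmethod
    def get_builder_cls():
        return ToyBuilder


def subclass_builder(prefix: str, builder_cls: type,
                     preprocess: Callable) -> type:
    # Same pattern as subclass_attention_metadata_builder above.
    def build(self, common_prefix_len, common_attn_metadata,
              fast_build=False):
        return builder_cls.build(self, common_prefix_len,
                                 preprocess(common_attn_metadata), fast_build)

    return type(prefix + builder_cls.__name__, (builder_cls, ),
                {"build": build})


WrappedBuilder = subclass_builder("Local_", ToyBuilder,
                                  lambda meta: meta + "_virtual_batches")
WrappedBackend = type("Local_ToyBackend", (ToyBackend, ),
                      {"get_builder_cls": lambda: WrappedBuilder})

assert WrappedBuilder().build(0, "meta") == ("built", "meta_virtual_batches")
assert WrappedBackend.get_builder_cls() is WrappedBuilder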
View File

@@ -158,9 +158,9 @@ class EagleProposer:
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builders[
-            0].build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                  draft_index=0)
+        attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
+            .build_for_drafting(common_attn_metadata=common_attn_metadata,
+                                draft_index=0)

         # At this moment, we assume all eagle layers belong to the same KV
         # cache group, thus using the same attention metadata.
@@ -349,7 +349,8 @@ class EagleProposer:
         hidden_states: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> list[torch.Tensor]:
-        tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].metadata_builder
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)

View File

@@ -53,11 +53,11 @@ class CPUModelRunner(GPUModelRunner):
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")

-        assert type(
-            self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
+        assert type(self.attn_groups[0]
+                    [0].metadata_builder) is TorchSDPAMetadataBuilderV1

-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
-                                                     scheduler_output)
+        self.attn_groups[0][0].metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)

     def _postprocess_tenosrs(self) -> None:
         # Note: replace device tensors with cpu tensors

View File

@@ -3,7 +3,10 @@
 import dataclasses
 import gc
+import itertools
 import time
+from collections import defaultdict
+from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union, cast
@@ -14,9 +17,9 @@ import torch.nn as nn
 from tqdm import tqdm

 import vllm.envs as envs
-from vllm.attention import AttentionType, get_attn_backend
+from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.layer import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, update_config)
@@ -50,7 +53,6 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
-    make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
@@ -73,8 +75,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
-                    initialize_kv_cache_for_kv_sharing,
+from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
+                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

 if TYPE_CHECKING:
@@ -162,8 +164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # self.model: nn.Module  # Set after load_model
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
-        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
-        self.attn_backends: list[type[AttentionBackend]] = []
+        # indexes: [kv_cache_group_id][attn_group]
+        self.attn_groups: list[list[AttentionGroup]] = []
         # self.kv_cache_config: KVCacheConfig

         # req_id -> (input_id -> encoder_output)
@@ -830,81 +832,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     spec_decode_common_attn_metadata is None:
                 spec_decode_common_attn_metadata = common_attn_metadata

-            if isinstance(kv_cache_group_spec.kv_cache_spec,
-                          ChunkedLocalAttentionSpec):
-                common_attn_metadata = make_local_attention_virtual_batches(
-                    kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
-                    common_attn_metadata, self.cache_config.block_size)
-
-            # Prepare for cascade attention if enabled & beneficial.
-            common_prefix_len = 0
-            builder = self.attn_metadata_builders[kv_cache_group_id]
-            if self.cascade_attn_enabled:
-                common_prefix_len = self._compute_cascade_attn_prefix_len(
-                    num_scheduled_tokens,
-                    scheduler_output.
-                    num_common_prefix_blocks[kv_cache_group_id],
-                    kv_cache_group_spec.kv_cache_spec,
-                    builder,
-                )
-
-            attn_metadata_i = (builder.build(
-                common_prefix_len=common_prefix_len,
-                common_attn_metadata=common_attn_metadata,
-            ))
-
-            fast_prefill_metadata = attn_metadata_i
-            if (self.cache_config.kv_sharing_fast_prefill
-                    and self.kv_sharing_fast_prefill_eligible_layers):
-                # Dynamically create a a dataclass type that inherits
-                # from attention metadata type but includes additional
-                # fields logits_indices_padded and num_logits_indices
-                # which are required for prefill truncation
-                fast_prefill_metadata_type = (
-                    make_kv_sharing_fast_prefill_attention_metadata(
-                        metadata_cls=type(attn_metadata_i), ))
-                fast_prefill_metadata = fast_prefill_metadata_type(
-                    **dataclasses.asdict(attn_metadata_i),
-                    logits_indices_padded=logits_indices_padded,
-                    num_logits_indices=logits_indices.size(0),
-                )
-
-            for layer_name in kv_cache_group_spec.layer_names:
-                if (self.cache_config.kv_sharing_fast_prefill and layer_name
-                        in self.kv_sharing_fast_prefill_eligible_layers):
-                    attn_metadata[layer_name] = fast_prefill_metadata
-                    continue
-                attn_metadata[layer_name] = attn_metadata_i
-
-            # Hack for now to fix chunked local attention + no hybrid kv cache
-            # manager we can remove this once
-            # https://github.com/vllm-project/vllm/pull/21588
-            # is merged (i.e. properly handle different attention backends for
-            # the same kv_cache_spec)
-            if self.attention_chunk_size is not None \
-                and self.scheduler_config.disable_hybrid_kv_cache_manager:
-                if not hasattr(self, "local_attention_layers"):
-                    self.local_attention_layers = []
-                    attn_layers = get_layers_from_vllm_config(
-                        self.vllm_config, Attention)
-                    for layer_name, attn_module in attn_layers.items():
-                        if attn_module.use_irope:
-                            self.local_attention_layers.append(layer_name)
-
-                local_attn_metadata_i = (builder.build(
-                    common_prefix_len=0,
-                    common_attn_metadata=make_local_attention_virtual_batches(
-                        self.attention_chunk_size, common_attn_metadata,
-                        self.cache_config.block_size),
-                ))
-                for layer_name in self.local_attention_layers:
-                    attn_metadata[layer_name] = local_attn_metadata_i
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                # Prepare for cascade attention if enabled & beneficial.
+                common_prefix_len = 0
+                builder = attn_group.metadata_builder
+                if self.cascade_attn_enabled:
+                    common_prefix_len = self._compute_cascade_attn_prefix_len(
+                        num_scheduled_tokens,
+                        scheduler_output.
+                        num_common_prefix_blocks[kv_cache_group_id],
+                        kv_cache_group_spec.kv_cache_spec,
+                        builder,
+                    )
+
+                attn_metadata_i = (builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                ))
+
+                fast_prefill_metadata = attn_metadata_i
+                if (self.cache_config.kv_sharing_fast_prefill
+                        and self.kv_sharing_fast_prefill_eligible_layers):
+                    # Dynamically create a a dataclass type that inherits
+                    # from attention metadata type but includes additional
+                    # fields logits_indices_padded and num_logits_indices
+                    # which are required for prefill truncation
+                    fast_prefill_metadata_type = (
+                        make_kv_sharing_fast_prefill_attention_metadata(
+                            metadata_cls=type(attn_metadata_i), ))
+                    fast_prefill_metadata = fast_prefill_metadata_type(
+                        **dataclasses.asdict(attn_metadata_i),
+                        logits_indices_padded=logits_indices_padded,
+                        num_logits_indices=logits_indices.size(0),
+                    )
+
+                for layer_name in attn_group.layer_names:
+                    if (self.cache_config.kv_sharing_fast_prefill
+                            and layer_name
+                            in self.kv_sharing_fast_prefill_eligible_layers):
+                        attn_metadata[layer_name] = fast_prefill_metadata
+                        continue
+                    attn_metadata[layer_name] = attn_metadata_i

         attention_cuda_graphs = all(
-            b.can_run_in_cudagraph(common_attn_metadata)
-            for b in self.attn_metadata_builders)
+            g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
+            for g in self._attn_group_iterator())

         # Hot-Swap lora model
         if self.lora_config:
@@ -2229,11 +2201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 block_table[kv_cache_group_id].slot_mapping[:num_tokens],
                 causal=True)

-            attn_metadata_i = self.attn_metadata_builders[
-                kv_cache_group_id].build_for_cudagraph_capture(
-                    common_attn_metadata)
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_metadata_i = attn_group.metadata_builder\
+                    .build_for_cudagraph_capture(common_attn_metadata)
+                for layer_name in kv_cache_group_spec.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
@@ -2565,88 +2537,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / (1 << 30))

-    def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
-    ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
-        if isinstance(kv_cache_spec, AttentionSpec):
-            attn_backend_i = get_attn_backend(
-                kv_cache_spec.head_size,
-                self.dtype,
-                kv_cache_spec.dtype,
-                kv_cache_spec.block_size,
-                self.model_config.is_attention_free,
-                use_mla=kv_cache_spec.use_mla,
-            )
-            if attn_backend_i is None:
-                error_msg = (f"Error with get_attn_backend: "
-                             f"{kv_cache_spec.head_size=}, "
-                             f"{self.dtype=}, {kv_cache_spec.dtype=}, "
-                             f"{kv_cache_spec.block_size=}, "
-                             f"{self.model_config.is_attention_free=}, "
-                             f"{kv_cache_spec.use_mla=}")
-                logger.error(error_msg)
-                raise NotImplementedError(
-                    "Non-Attention backend is not supported by V1 "
-                    "GPUModelRunner.")
-        elif isinstance(kv_cache_spec, MambaSpec):
-            attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
-        else:
-            raise ValueError(
-                f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
-        attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-            kv_cache_spec,
-            layer_names,
-            self.vllm_config,
-            self.device,
-        )
-
-        if self.full_cuda_graph:
-            if attn_metadata_builder_i.attn_cudagraph_support == \
-                AttentionCGSupport.NEVER:
-                raise ValueError(f"Full CUDAGraph not supported for "
-                                 f"{attn_backend_i.__name__}. Turn off "
-                                 f"CompilationConfig.full_cuda_graph or use a "
-                                 f" different attention backend.")
-            if attn_metadata_builder_i.attn_cudagraph_support == \
-                AttentionCGSupport.PURE_DECODE_ONLY:
-                # Limit the max cudagraph size to the max number of
-                # sequences for pure decode only cudagraph backend,
-                # whose max_query_len is 1.
-                self.cudagraph_batch_sizes = [
-                    size for size in self.cudagraph_batch_sizes
-                    if size <= self.scheduler_config.max_num_seqs
-                ]
-        return attn_backend_i, attn_metadata_builder_i
-
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the attention backends and attention metadata builders.
         """
-        assert len(self.attn_backends) == 0 and len(
-            self.attn_metadata_builders
-        ) == 0, "Attention backends are already initialized"
-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backend_i, attn_metadata_builder_i = (
-                self._initialize_single_attn_backend(
-                    kv_cache_spec, kv_cache_group_spec.layer_names))
-            self.attn_backends.append(attn_backend_i)
-            self.attn_metadata_builders.append(attn_metadata_builder_i)
+        assert len(self.attn_groups) == 0, \
+            "Attention backends are already initialized"
+
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+
+        def get_attn_backends_for_layers(
+                layer_names: list[str]
+        ) -> dict[type[AttentionBackend], list[str]]:
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+            # Dedupe based on full class name; this is a bit safer than using
+            # using the class itself as the key because when we create dynamic
+            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
+            # they are cached correctly, there will be different objects per
+            # layer.
+            for layer_name in layer_names:
+                attn_backend = attn_layers[layer_name].get_attn_backend()
+                key = attn_backend.full_cls_name()
+                attn_backends[key] = attn_backend
+                attn_backend_layers[key].append(layer_name)
+            return {
+                attn_backends[k]: v
+                for k, v in attn_backend_layers.items()
+            }
+
+        def create_attn_groups(
+            attn_backends_map: dict[AttentionBackend, list[str]],
+            kv_cache_spec: KVCacheSpec,
+        ) -> list[AttentionGroup]:
+            attn_groups: list[AttentionGroup] = []
+            for attn_backend, layer_names in attn_backends_map.items():
+                attn_metadata_builder_i = attn_backend.get_builder_cls()(
+                    kv_cache_spec,
+                    layer_names,
+                    self.vllm_config,
+                    self.device,
+                )
+                attn_group = AttentionGroup(attn_backend,
+                                            attn_metadata_builder_i,
+                                            layer_names)
+                attn_groups.append(attn_group)
+
+                if self.full_cuda_graph:
+                    if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.NEVER:
+                        raise ValueError(
+                            f"Full CUDAGraph not supported for "
+                            f"{attn_backend.__name__}. Turn off "
+                            f"CompilationConfig.full_cuda_graph or use a "
+                            f" different attention backend.")
+                    if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.PURE_DECODE_ONLY:
+                        # Limit the max cudagraph size to the max number of
+                        # sequences for pure decode only cudagraph backend,
+                        # whose max_query_len is 1.
+                        self.cudagraph_batch_sizes = [
+                            size for size in self.cudagraph_batch_sizes
+                            if size <= self.scheduler_config.max_num_seqs
+                        ]
+            return attn_groups
+
+        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+            if isinstance(kv_cache_spec, AttentionSpec):
+                attn_backends = get_attn_backends_for_layers(
+                    kv_cache_group_spec.layer_names)
+            # TODO(lucas): move `get_mamba_attn_backend` into the mamba
+            # layers like above
+            elif isinstance(kv_cache_spec, MambaSpec):
+                attn_backends = {
+                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
+                    kv_cache_group_spec.layer_names
+                }
+            else:
+                raise ValueError(
+                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
+            self.attn_groups.append(
+                create_attn_groups(attn_backends, kv_cache_spec))

         # Calculate reorder batch threshold (if neeeded)
         self.calculate_reorder_batch_threshold()

-        if len(self.attn_backends) > 0:
+        if len(self.attn_groups) > 0:
             return

         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         attn_specs = list[AttentionSpec]()
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

         for attn_module in attn_layers.values():
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
@@ -2666,11 +2650,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert len(attn_specs) == len(attn_layers), \
                 "All or none of the layers are expected to be encoder-only"

-            attn_backend, attn_metadata_builder = (
-                self._initialize_single_attn_backend(attn_specs[0],
-                                                     attn_layers.keys()))
-            self.attn_backends.append(attn_backend)
-            self.attn_metadata_builders.append(attn_metadata_builder)
+            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+
+            self.attn_groups.append(
+                create_attn_groups(attn_backends, attn_specs[0]))
             self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
@@ -2678,7 +2661,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Check that if any backends reorder batches; that the reordering
         is compatible (e.g., decode threshold is the same)
         """
-        for attn_metadata_builder_i in self.attn_metadata_builders:
+        for group in self._attn_group_iterator():
+            attn_metadata_builder_i = group.metadata_builder
             # check that if any backends reorder batches; that the reordering
             # is compatible (e.g., decode threshold is the same)
             reorder_batch_threshold_i = (
@@ -2752,6 +2737,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )), "Some layers are not correctly initialized"
         return kv_cache_raw_tensors

+    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        return itertools.chain.from_iterable(self.attn_groups)
+
+    def _kv_cache_spec_attn_group_iterator(
+            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
+            for attn_group in attn_groups:
+                yield self.kv_cache_config.kv_cache_groups[
+                    kv_cache_spec_id].kv_cache_spec, attn_group
+
     def _reshape_kv_cache_tensors(
         self,
         kv_cache_config: KVCacheConfig,
@@ -2770,23 +2767,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         kv_caches: dict[str, torch.Tensor] = {}
         has_attn, has_mamba = False, False
-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            for layer_name in kv_cache_group_spec.layer_names:
+        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                 num_blocks = (raw_tensor.numel() //
                               kv_cache_spec.page_size_bytes)
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
+                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = self.attn_backends[
-                            i].get_kv_cache_stride_order()
+                        kv_cache_stride_order = \
+                            attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(
                             kv_cache_shape)
                     except (AttributeError, NotImplementedError):
@@ -2850,15 +2846,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_raw_tensors: The KV cache buffer of each layer.
         """
-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            for layer_name in kv_cache_group_spec.layer_names:
+        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+            for layer_name in group.layer_names:
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 num_blocks = (raw_tensor.numel() //
                               kv_cache_spec.page_size_bytes)
                 if isinstance(kv_cache_spec, AttentionSpec):
-                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
+                    kv_cache_shape = group.backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     if kv_cache_shape[0] != num_blocks or kv_cache_shape[
@@ -2893,6 +2888,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.shared_kv_cache_layers,
             kv_cache_config.kv_cache_groups,
             kv_caches,
+            self.attn_groups,
         )
         attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention)
@@ -2958,9 +2954,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 continue

             # TODO: Support other attention modules, e.g., cross-attention
+            # TODO(lucas): move the attention specs into the model layers like
+            # the attention backends
             if attn_module.attn_type == AttentionType.DECODER:
-                use_local_attention = (self.attention_chunk_size is not None
-                                       and attn_module.use_irope)
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2969,10 +2965,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
-                    assert not use_local_attention, (
-                        "attention module can not be with ",
-                        "both local attention and sliding window")
-                elif use_local_attention:
+                elif self.attention_chunk_size is not None \
+                        and isinstance(attn_module, ChunkedLocalAttention):
                     kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
@@ -3043,7 +3037,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Use the first attention metadata builder
             # to create encoder attention metadata
-            builder = self.attn_metadata_builders[0]
+            builder = self.attn_groups[0][0].metadata_builder
             dummy_block_table = torch.zeros((num_reqs, 1),
                                             dtype=torch.int32,

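A self-contained sketch (hypothetical classes, not the runner's real objects) of the dedupe-by-name idea in get_attn_backends_for_layers above: dynamically created wrapper backends that were not cached compare unequal as objects, but their (module, qualname) key from full_cls_name() still collapses them into a single attention group.

# Hypothetical illustration of keying by full class name.
from collections import defaultdict


class BaseBackend:
    @classmethod
    def full_cls_name(cls):
        return (cls.__module__, cls.__qualname__)


def make_uncached_subclass():
    # Simulate a dynamically created wrapper that was *not* lru_cached.
    return type("ChunkedLocal_BaseBackend", (BaseBackend, ), {})


b1, b2 = make_uncached_subclass(), make_uncached_subclass()
assert b1 is not b2                              # distinct class objects
assert b1.full_cls_name() == b2.full_cls_name()  # ...but the same key


def group_by_full_name(layer_to_backend):
    backends, layers = {}, defaultdict(list)
    for layer_name, backend in layer_to_backend.items():
        key = backend.full_cls_name()
        backends[key] = backend
        layers[key].append(layer_name)
    return {backends[k]: names for k, names in layers.items()}


grouped = group_by_full_name({"layers.0.attn": b1, "layers.1.attn": b2})
assert len(grouped) == 1   # both layers land in a single attention group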
View File

@@ -15,8 +15,9 @@ import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr

 import vllm.envs as envs
+from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import (ParallelConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
@@ -518,7 +519,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 continue

             if attn_module.attn_type == AttentionType.DECODER:
-                if attn_module.use_irope:
+                if isinstance(attn_module, ChunkedLocalAttention):
                     logger.warning_once(
                         "Using irope in Pallas is not supported yet, it "
                         "will fall back to global attention for long context.")

View File

@@ -1,14 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

 import torch

+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import ModelConfig, SchedulerConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.registry import MultiModalRegistry
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
@@ -122,6 +125,13 @@ class MultiModalBudget:
         return max_items_per_prompt, max_items_per_batch


+@dataclass
+class AttentionGroup:
+    backend: type[AttentionBackend]
+    metadata_builder: AttentionMetadataBuilder
+    layer_names: list[str]
+
+
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
@@ -196,6 +206,8 @@ def initialize_kv_cache_for_kv_sharing(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
     kv_caches: dict[str, torch.Tensor],
+    # Optional for now to avoid breaking TPU
+    attn_groups: Optional[list[list[AttentionGroup]]] = None,
 ) -> None:
     """
     Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -225,6 +237,15 @@ def initialize_kv_cache_for_kv_sharing(
         group_idx = layer_to_kv_cache_group_idx[target_layer_name]
         kv_cache_groups[group_idx].layer_names.append(layer_name)

+        if attn_groups is not None:
+            assert len(attn_groups[group_idx]) == 1, (
+                "Only one attention group per KV cache group is supported "
+                "for KV-cache sharing for now.")
+            # TODO(lucas): I think in the future the layers that re-use a
+            # KV cache will be in a different attention group so we can
+            # remove this code from here.
+            attn_groups[group_idx][0].layer_names.append(layer_name)
+

 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],
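
A tiny sketch (toy data structures, not vLLM's) of the KV-sharing hook added above: a layer that reuses another layer's KV cache is appended both to the target's KV cache group and to that group's single attention group.

# Hypothetical illustration of the append-to-both-groups behavior.
from dataclasses import dataclass, field


@dataclass
class Group:                       # toy stand-in for a group with layer_names
    layer_names: list = field(default_factory=list)


kv_cache_groups = [Group(["model.layers.0.attn"])]
attn_groups = [[Group(["model.layers.0.attn"])]]   # one attn group per kv group

shared_kv_cache_layers = {"model.layers.1.attn": "model.layers.0.attn"}
for layer_name, target_layer_name in shared_kv_cache_layers.items():
    group_idx = 0                  # KV cache group of the target layer
    kv_cache_groups[group_idx].layer_names.append(layer_name)
    assert len(attn_groups[group_idx]) == 1   # single-group restriction above
    attn_groups[group_idx][0].layer_names.append(layer_name)

assert attn_groups[0][0].layer_names == [
    "model.layers.0.attn", "model.layers.1.attn"
]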