Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:26:12 +08:00)
[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Parent: f825c6bd22
Commit: 1dc8a70b6d
@@ -313,7 +313,8 @@ def test_propose(num_speculative_tokens, backend):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_metadata_builders = [attn_metadata_builder]
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder

     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
         return rnd_stride

     # Patch the attention backend class and re-trigger the KV cache creation.
-    for attn_backend in model_runner.attn_backends:
+    for attn_group in model_runner._attn_group_iterator():
+        attn_backend = attn_group.backend
         monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
                             rnd_stride_order)

-    model_runner.attn_backends = []
-    model_runner.attn_metadata_builders = []
+    model_runner.attn_groups = []
     model_runner.initialize_kv_cache(model_runner.kv_cache_config)

     # Shape is unchanged, but layout may differ
@@ -106,6 +106,10 @@ class AttentionBackend(ABC):
                               block_size: int, num_seqs: int, num_queries: int) -> None:
         raise NotImplementedError

+    @classmethod
+    def full_cls_name(cls) -> tuple[str, str]:
+        return (cls.__module__, cls.__qualname__)
+

 @dataclass
 class AttentionMetadata:
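For illustration only, a minimal self-contained sketch (stand-in names, not vLLM's real backends) of why a `(module, qualname)` tuple works as a deduplication key even when backend classes are created dynamically and therefore are distinct class objects:

class _FakeBackend:
    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)

def make_wrapped_backend() -> type:
    # Without caching, each call returns a brand-new class object.
    return type("WrappedBackend", (_FakeBackend,), {})

a, b = make_wrapped_backend(), make_wrapped_backend()
assert a is not b                              # distinct class objects
assert a.full_cls_name() == b.full_cls_name()  # but the same dedupe key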
@@ -9,6 +9,7 @@ import torch.nn.functional as F

 import vllm.envs as envs
 from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config

@@ -80,6 +81,7 @@ class Attention(nn.Module):
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        attn_backend: Optional[type[AttentionBackend]] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -137,15 +139,6 @@
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window

-        # For v1 we have backend agnostic iRoPE (local chunked attention)
-        # we have to store the flag on the layer so gpu model runner can
-        # set KVSpec appropriately (and pop it so it doesnt get passed to
-        # the backends)
-        if envs.VLLM_USE_V1:
-            self.use_irope = extra_impl_args.pop("use_irope", False)
-        else:
-            self.use_irope = extra_impl_args.get("use_irope", False)
-
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@ class Attention(nn.Module):
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype,
-                                        block_size,
-                                        is_attention_free,
-                                        use_mla=use_mla)
-        impl_cls = attn_backend.get_impl_cls()
+        if attn_backend is None:
+            self.attn_backend = get_attn_backend(head_size,
+                                                 dtype,
+                                                 kv_cache_dtype,
+                                                 block_size,
+                                                 is_attention_free,
+                                                 use_mla=use_mla)
+        else:
+            self.attn_backend = attn_backend
+
+        impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              logits_soft_cap, attn_type,
                              kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.backend = backend_name_to_enum(self.attn_backend.get_name())
         self.dtype = dtype

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()

-        self.use_output = attn_backend.accept_output_buffer
+        self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)

+    def get_attn_backend(self) -> type[AttentionBackend]:
+        return self.attn_backend
+

 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
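As a rough sketch of the pattern these two changes enable (hypothetical stand-in classes, not vLLM code): a layer can be constructed with an explicit backend class, falls back to a default otherwise, and exposes its choice so a runner can group layers by backend later.

from typing import Optional

class _DefaultBackend:
    pass

class _LocalBackend(_DefaultBackend):
    pass

class _SketchAttention:
    def __init__(self, attn_backend: Optional[type] = None) -> None:
        # An explicit backend class wins; otherwise fall back to the
        # globally selected one (get_attn_backend(...) in the real code).
        self.attn_backend = attn_backend if attn_backend is not None \
            else _DefaultBackend

    def get_attn_backend(self) -> type:
        return self.attn_backend

assert _SketchAttention().get_attn_backend() is _DefaultBackend
assert _SketchAttention(_LocalBackend).get_attn_backend() is _LocalBackend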
vllm/attention/layers/chunked_local_attention.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches,
    subclass_attention_backend, subclass_attention_metadata_builder)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    def build_preprocess_fn(cm: CommonAttentionMetadata):
        return make_local_attention_virtual_batches(attention_chunk_size, cm,
                                                    block_size)

    # Dynamically create a new attention backend that wraps the
    # underlying attention backend but applies
    # `make_local_attention_virtual_batches` before calling `build(...)`
    builder_cls = subclass_attention_metadata_builder(
        name_prefix=prefix,
        builder_cls=underlying_attn_backend.get_builder_cls(),
        build_preprocess_fn=build_preprocess_fn)
    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=builder_cls)

    return attn_backend


class ChunkedLocalAttention(Attention):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # in v0 the local attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
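A note on the `functools.lru_cache` on the factory above, with a minimal self-contained sketch (stand-in classes, not the real backends): caching means every layer that asks for the same (underlying backend, chunk size, block size) receives the *same* dynamically created subclass, which is what lets the model runner later deduplicate backends by class.

import functools

class _FlashBackendStub:
    pass

@functools.lru_cache
def _create_local_backend(underlying: type, chunk_size: int,
                          block_size: int) -> type:
    # Mimics the dynamic subclass creation; returns a new class per unique key.
    name = f"ChunkedLocalAttention_{chunk_size}_{block_size}_{underlying.__name__}"
    return type(name, (underlying,), {})

b1 = _create_local_backend(_FlashBackendStub, 8192, 16)
b2 = _create_local_backend(_FlashBackendStub, 8192, 16)
assert b1 is b2  # cached: every matching layer sees the same subclass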
@@ -142,7 +142,7 @@ def get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool,
+    is_attention_free: bool = False,
     use_mla: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -25,6 +25,7 @@ from torch import nn
 from transformers import Llama4TextConfig

 from vllm.attention import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size

@@ -194,17 +195,18 @@ class Llama4Attention(nn.Module):
             is_neox_style=is_neox_style,
         ) if not self.nope else None

-        self.attn = Attention(
+        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
-            per_layer_sliding_window=None,
-            use_irope=not self.nope,
             prefix=f"{prefix}.attn",
-        )
+            **({
+                "attention_chunk_size": config.attention_chunk_size
+            } if not self.nope else {}))

     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -5,12 +5,12 @@ import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
+                    TypeVar)

 import numpy as np
 import torch

-from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv

@@ -20,6 +20,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch

 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
 from vllm.logger import init_logger
@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
     )


+def subclass_attention_metadata_builder(
+    name_prefix: str,
+    builder_cls: type[AttentionMetadataBuilder[M]],
+    build_preprocess_fn: Callable[[CommonAttentionMetadata],
+                                  CommonAttentionMetadata],
+) -> type[AttentionMetadataBuilder[M]]:
+    """
+    Return a new subclass of `builder_cls` whose .build(...) method
+    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
+    """
+    name: str = name_prefix + builder_cls.__name__  # type: ignore
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False):
+        return builder_cls.build(self, common_prefix_len,
+                                 build_preprocess_fn(common_attn_metadata),
+                                 fast_build)
+
+    Wrapped = type(
+        name,
+        (builder_cls, ),  # inherit from the original
+        {
+            "build": build,
+        })
+    return Wrapped  # type: ignore
+
+
+def subclass_attention_backend(
+    name_prefix: str, attention_backend_cls: type[AttentionBackend],
+    builder_cls: type[AttentionMetadataBuilder[M]]
+) -> type[AttentionBackend]:
+    """
+    Return a new subclass where `get_builder_cls` returns `builder_cls`.
+    """
+    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
+
+    return type(name, (attention_backend_cls, ),
+                {"get_builder_cls": lambda: builder_cls})
+
+
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
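To make the wrapping pattern of the two helpers above concrete, here is a minimal self-contained sketch under stand-in types (not vLLM's real builders or metadata): the derived class rewrites the common metadata and then defers to the parent's `build()`.

from typing import Callable

class _CommonMeta:
    def __init__(self, num_tokens: int) -> None:
        self.num_tokens = num_tokens

class _BaseBuilder:
    def build(self, common_prefix_len: int, cm: _CommonMeta,
              fast_build: bool = False) -> str:
        return f"built(prefix={common_prefix_len}, tokens={cm.num_tokens})"

def _subclass_builder(prefix: str, builder_cls: type,
                      preprocess: Callable[[_CommonMeta], _CommonMeta]) -> type:
    # The derived class rewrites the metadata, then delegates to the parent.
    def build(self, common_prefix_len, cm, fast_build=False):
        return builder_cls.build(self, common_prefix_len, preprocess(cm),
                                 fast_build)
    return type(prefix + builder_cls.__name__, (builder_cls,),
                {"build": build})

_Local = _subclass_builder("Local_", _BaseBuilder,
                           lambda cm: _CommonMeta(cm.num_tokens * 2))
print(_Local().build(0, _CommonMeta(4)))  # -> built(prefix=0, tokens=8)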
@@ -158,9 +158,9 @@ class EagleProposer:
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builders[
-            0].build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                  draft_index=0)
+        attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
+            .build_for_drafting(common_attn_metadata=common_attn_metadata,
+                                draft_index=0)

         # At this moment, we assume all eagle layers belong to the same KV
         # cache group, thus using the same attention metadata.

@@ -349,7 +349,8 @@ class EagleProposer:
         hidden_states: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> list[torch.Tensor]:
-        tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].metadata_builder
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
@@ -53,11 +53,11 @@ class CPUModelRunner(GPUModelRunner):
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")

-        assert type(
-            self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
+        assert type(self.attn_groups[0]
+                    [0].metadata_builder) is TorchSDPAMetadataBuilderV1

-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
-                                                     scheduler_output)
+        self.attn_groups[0][0].metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)

     def _postprocess_tenosrs(self) -> None:
         # Note: replace device tensors with cpu tensors
@@ -3,7 +3,10 @@

 import dataclasses
 import gc
+import itertools
 import time
+from collections import defaultdict
+from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union, cast

@@ -14,9 +17,9 @@ import torch.nn as nn
 from tqdm import tqdm

 import vllm.envs as envs
-from vllm.attention import AttentionType, get_attn_backend
+from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.layer import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, update_config)

@@ -50,7 +53,6 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
-    make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,

@@ -73,8 +75,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
-                    initialize_kv_cache_for_kv_sharing,
+from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
+                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

 if TYPE_CHECKING:

@@ -162,8 +164,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # self.model: nn.Module  # Set after load_model
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
-        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
-        self.attn_backends: list[type[AttentionBackend]] = []
+        # indexes: [kv_cache_group_id][attn_group]
+        self.attn_groups: list[list[AttentionGroup]] = []
         # self.kv_cache_config: KVCacheConfig

         # req_id -> (input_id -> encoder_output)
@@ -830,81 +832,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     spec_decode_common_attn_metadata is None:
                 spec_decode_common_attn_metadata = common_attn_metadata

-            if isinstance(kv_cache_group_spec.kv_cache_spec,
-                          ChunkedLocalAttentionSpec):
-                common_attn_metadata = make_local_attention_virtual_batches(
-                    kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
-                    common_attn_metadata, self.cache_config.block_size)
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                # Prepare for cascade attention if enabled & beneficial.
+                common_prefix_len = 0
+                builder = attn_group.metadata_builder
+                if self.cascade_attn_enabled:
+                    common_prefix_len = self._compute_cascade_attn_prefix_len(
+                        num_scheduled_tokens,
+                        scheduler_output.
+                        num_common_prefix_blocks[kv_cache_group_id],
+                        kv_cache_group_spec.kv_cache_spec,
+                        builder,
+                    )

-            # Prepare for cascade attention if enabled & beneficial.
-            common_prefix_len = 0
-            builder = self.attn_metadata_builders[kv_cache_group_id]
-            if self.cascade_attn_enabled:
-                common_prefix_len = self._compute_cascade_attn_prefix_len(
-                    num_scheduled_tokens,
-                    scheduler_output.
-                    num_common_prefix_blocks[kv_cache_group_id],
-                    kv_cache_group_spec.kv_cache_spec,
-                    builder,
-                )
-
-            attn_metadata_i = (builder.build(
-                common_prefix_len=common_prefix_len,
-                common_attn_metadata=common_attn_metadata,
-            ))
-
-            fast_prefill_metadata = attn_metadata_i
-            if (self.cache_config.kv_sharing_fast_prefill
-                    and self.kv_sharing_fast_prefill_eligible_layers):
-                # Dynamically create a a dataclass type that inherits
-                # from attention metadata type but includes additional
-                # fields logits_indices_padded and num_logits_indices
-                # which are required for prefill truncation
-                fast_prefill_metadata_type = (
-                    make_kv_sharing_fast_prefill_attention_metadata(
-                        metadata_cls=type(attn_metadata_i), ))
-                fast_prefill_metadata = fast_prefill_metadata_type(
-                    **dataclasses.asdict(attn_metadata_i),
-                    logits_indices_padded=logits_indices_padded,
-                    num_logits_indices=logits_indices.size(0),
-                )
-
-            for layer_name in kv_cache_group_spec.layer_names:
-                if (self.cache_config.kv_sharing_fast_prefill and layer_name
-                        in self.kv_sharing_fast_prefill_eligible_layers):
-                    attn_metadata[layer_name] = fast_prefill_metadata
-                    continue
-
-                attn_metadata[layer_name] = attn_metadata_i
-
-            # Hack for now to fix chunked local attention + no hybrid kv cache
-            # manager we can remove this once
-            # https://github.com/vllm-project/vllm/pull/21588
-            # is merged (i.e. properly handle different attention backends for
-            # the same kv_cache_spec)
-            if self.attention_chunk_size is not None \
-                and self.scheduler_config.disable_hybrid_kv_cache_manager:
-                if not hasattr(self, "local_attention_layers"):
-                    self.local_attention_layers = []
-                    attn_layers = get_layers_from_vllm_config(
-                        self.vllm_config, Attention)
-                    for layer_name, attn_module in attn_layers.items():
-                        if attn_module.use_irope:
-                            self.local_attention_layers.append(layer_name)
-
-                local_attn_metadata_i = (builder.build(
-                    common_prefix_len=0,
-                    common_attn_metadata=make_local_attention_virtual_batches(
-                        self.attention_chunk_size, common_attn_metadata,
-                        self.cache_config.block_size),
+                attn_metadata_i = (builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
                 ))

-                for layer_name in self.local_attention_layers:
-                    attn_metadata[layer_name] = local_attn_metadata_i
+                fast_prefill_metadata = attn_metadata_i
+                if (self.cache_config.kv_sharing_fast_prefill
+                        and self.kv_sharing_fast_prefill_eligible_layers):
+                    # Dynamically create a a dataclass type that inherits
+                    # from attention metadata type but includes additional
+                    # fields logits_indices_padded and num_logits_indices
+                    # which are required for prefill truncation
+                    fast_prefill_metadata_type = (
+                        make_kv_sharing_fast_prefill_attention_metadata(
+                            metadata_cls=type(attn_metadata_i), ))
+                    fast_prefill_metadata = fast_prefill_metadata_type(
+                        **dataclasses.asdict(attn_metadata_i),
+                        logits_indices_padded=logits_indices_padded,
+                        num_logits_indices=logits_indices.size(0),
+                    )

+                for layer_name in attn_group.layer_names:
+                    if (self.cache_config.kv_sharing_fast_prefill
+                            and layer_name
+                            in self.kv_sharing_fast_prefill_eligible_layers):
+                        attn_metadata[layer_name] = fast_prefill_metadata
+                        continue
+                    attn_metadata[layer_name] = attn_metadata_i

         attention_cuda_graphs = all(
-            b.can_run_in_cudagraph(common_attn_metadata)
-            for b in self.attn_metadata_builders)
+            g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
+            for g in self._attn_group_iterator())

         # Hot-Swap lora model
         if self.lora_config:
@@ -2229,11 +2201,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 block_table[kv_cache_group_id].slot_mapping[:num_tokens],
                 causal=True)

-            attn_metadata_i = self.attn_metadata_builders[
-                kv_cache_group_id].build_for_cudagraph_capture(
-                    common_attn_metadata)
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_metadata_i = attn_group.metadata_builder\
+                    .build_for_cudagraph_capture(common_attn_metadata)
+                for layer_name in kv_cache_group_spec.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
@@ -2565,88 +2537,100 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / (1 << 30))

-    def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
-    ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
-        if isinstance(kv_cache_spec, AttentionSpec):
-            attn_backend_i = get_attn_backend(
-                kv_cache_spec.head_size,
-                self.dtype,
-                kv_cache_spec.dtype,
-                kv_cache_spec.block_size,
-                self.model_config.is_attention_free,
-                use_mla=kv_cache_spec.use_mla,
-            )
-            if attn_backend_i is None:
-                error_msg = (f"Error with get_attn_backend: "
-                             f"{kv_cache_spec.head_size=}, "
-                             f"{self.dtype=}, {kv_cache_spec.dtype=}, "
-                             f"{kv_cache_spec.block_size=}, "
-                             f"{self.model_config.is_attention_free=}, "
-                             f"{kv_cache_spec.use_mla=}")
-                logger.error(error_msg)
-                raise NotImplementedError(
-                    "Non-Attention backend is not supported by V1 "
-                    "GPUModelRunner.")
-        elif isinstance(kv_cache_spec, MambaSpec):
-            attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
-        else:
-            raise ValueError(
-                f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
-        attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-            kv_cache_spec,
-            layer_names,
-            self.vllm_config,
-            self.device,
-        )
-
-        if self.full_cuda_graph:
-            if attn_metadata_builder_i.attn_cudagraph_support == \
-                AttentionCGSupport.NEVER:
-                raise ValueError(f"Full CUDAGraph not supported for "
-                                 f"{attn_backend_i.__name__}. Turn off "
-                                 f"CompilationConfig.full_cuda_graph or use a "
-                                 f" different attention backend.")
-            if attn_metadata_builder_i.attn_cudagraph_support == \
-                AttentionCGSupport.PURE_DECODE_ONLY:
-                # Limit the max cudagraph size to the max number of
-                # sequences for pure decode only cudagraph backend,
-                # whose max_query_len is 1.
-                self.cudagraph_batch_sizes = [
-                    size for size in self.cudagraph_batch_sizes
-                    if size <= self.scheduler_config.max_num_seqs
-                ]
-        return attn_backend_i, attn_metadata_builder_i
-
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the attention backends and attention metadata builders.
         """
-        assert len(self.attn_backends) == 0 and len(
-            self.attn_metadata_builders
-        ) == 0, "Attention backends are already initialized"
-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+        assert len(self.attn_groups) == 0, \
+            "Attention backends are already initialized"
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

-            attn_backend_i, attn_metadata_builder_i = (
-                self._initialize_single_attn_backend(
-                    kv_cache_spec, kv_cache_group_spec.layer_names))
-            self.attn_backends.append(attn_backend_i)
-            self.attn_metadata_builders.append(attn_metadata_builder_i)
+        def get_attn_backends_for_layers(
+            layer_names: list[str]
+        ) -> dict[type[AttentionBackend], list[str]]:
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+            # Dedupe based on full class name; this is a bit safer than using
+            # using the class itself as the key because when we create dynamic
+            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
+            # they are cached correctly, there will be different objects per
+            # layer.
+            for layer_name in layer_names:
+                attn_backend = attn_layers[layer_name].get_attn_backend()
+                key = attn_backend.full_cls_name()
+                attn_backends[key] = attn_backend
+                attn_backend_layers[key].append(layer_name)
+            return {
+                attn_backends[k]: v
+                for k, v in attn_backend_layers.items()
+            }
+
+        def create_attn_groups(
+            attn_backends_map: dict[AttentionBackend, list[str]],
+            kv_cache_spec: KVCacheSpec,
+        ) -> list[AttentionGroup]:
+            attn_groups: list[AttentionGroup] = []
+            for attn_backend, layer_names in attn_backends_map.items():
+                attn_metadata_builder_i = attn_backend.get_builder_cls()(
+                    kv_cache_spec,
+                    layer_names,
+                    self.vllm_config,
+                    self.device,
+                )
+                attn_group = AttentionGroup(attn_backend,
+                                            attn_metadata_builder_i,
+                                            layer_names)
+                attn_groups.append(attn_group)
+
+                if self.full_cuda_graph:
+                    if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.NEVER:
+                        raise ValueError(
+                            f"Full CUDAGraph not supported for "
+                            f"{attn_backend.__name__}. Turn off "
+                            f"CompilationConfig.full_cuda_graph or use a "
+                            f" different attention backend.")
+                    if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.PURE_DECODE_ONLY:
+                        # Limit the max cudagraph size to the max number of
+                        # sequences for pure decode only cudagraph backend,
+                        # whose max_query_len is 1.
+                        self.cudagraph_batch_sizes = [
+                            size for size in self.cudagraph_batch_sizes
+                            if size <= self.scheduler_config.max_num_seqs
+                        ]
+
+            return attn_groups
+
+        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+            if isinstance(kv_cache_spec, AttentionSpec):
+                attn_backends = get_attn_backends_for_layers(
+                    kv_cache_group_spec.layer_names)
+            # TODO(lucas): move `get_mamba_attn_backend` into the mamba
+            # layers like above
+            elif isinstance(kv_cache_spec, MambaSpec):
+                attn_backends = {
+                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
+                    kv_cache_group_spec.layer_names
+                }
+            else:
+                raise ValueError(
+                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
+
+            self.attn_groups.append(
+                create_attn_groups(attn_backends, kv_cache_spec))

         # Calculate reorder batch threshold (if neeeded)
         self.calculate_reorder_batch_threshold()

-        if len(self.attn_backends) > 0:
+        if len(self.attn_groups) > 0:
             return

         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         attn_specs = list[AttentionSpec]()
         attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         for attn_module in attn_layers.values():

             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
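For orientation, a minimal self-contained sketch (plain stand-in classes, not the runner's real types) of the grouping idea the hunk above introduces: layers in one KV cache group that resolve to different backends are split into separate groups, with the backend's (module, qualname) pair used as the dedupe key.

from dataclasses import dataclass, field

@dataclass
class _Group:
    backend: type
    layer_names: list = field(default_factory=list)

class _BackendA:
    pass

class _BackendB:
    pass

_layer_backends = {"layers.0.attn": _BackendA,
                   "layers.1.attn": _BackendA,
                   "layers.2.attn": _BackendB}

def _group_layers(layer_names):
    groups = {}
    for name in layer_names:
        backend = _layer_backends[name]
        key = (backend.__module__, backend.__qualname__)  # like full_cls_name()
        groups.setdefault(key, _Group(backend)).layer_names.append(name)
    return list(groups.values())

groups = _group_layers(["layers.0.attn", "layers.1.attn", "layers.2.attn"])
assert [g.layer_names for g in groups] == [["layers.0.attn", "layers.1.attn"],
                                           ["layers.2.attn"]]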
@@ -2666,11 +2650,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(attn_specs) == len(attn_layers), \
             "All or none of the layers are expected to be encoder-only"

-        attn_backend, attn_metadata_builder = (
-            self._initialize_single_attn_backend(attn_specs[0],
-                                                 attn_layers.keys()))
-        self.attn_backends.append(attn_backend)
-        self.attn_metadata_builders.append(attn_metadata_builder)
+        attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+
+        self.attn_groups.append(
+            create_attn_groups(attn_backends, attn_specs[0]))
         self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
@@ -2678,7 +2661,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Check that if any backends reorder batches; that the reordering
         is compatible (e.g., decode threshold is the same)
         """
-        for attn_metadata_builder_i in self.attn_metadata_builders:
+        for group in self._attn_group_iterator():
+            attn_metadata_builder_i = group.metadata_builder
+
             # check that if any backends reorder batches; that the reordering
             # is compatible (e.g., decode threshold is the same)
             reorder_batch_threshold_i = (
@@ -2752,6 +2737,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )), "Some layers are not correctly initialized"
         return kv_cache_raw_tensors

+    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        return itertools.chain.from_iterable(self.attn_groups)
+
+    def _kv_cache_spec_attn_group_iterator(
+            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
+            for attn_group in attn_groups:
+                yield self.kv_cache_config.kv_cache_groups[
+                    kv_cache_spec_id].kv_cache_spec, attn_group
+
     def _reshape_kv_cache_tensors(
         self,
         kv_cache_config: KVCacheConfig,
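A tiny illustration of the data layout behind `_attn_group_iterator` (plain strings standing in for AttentionGroup objects): the nested [kv_cache_group_id][attn_group] lists flatten with itertools.chain.from_iterable.

import itertools

attn_groups = [["group0_a", "group0_b"], ["group1_a"]]
assert list(itertools.chain.from_iterable(attn_groups)) == [
    "group0_a", "group0_b", "group1_a"
]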
@@ -2770,23 +2767,22 @@
         """
         kv_caches: dict[str, torch.Tensor] = {}
         has_attn, has_mamba = False, False
-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            for layer_name in kv_cache_group_spec.layer_names:
+        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                 num_blocks = (raw_tensor.numel() //
                               kv_cache_spec.page_size_bytes)
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
+                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = self.attn_backends[
-                            i].get_kv_cache_stride_order()
+                        kv_cache_stride_order = \
+                            attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(
                             kv_cache_shape)
                     except (AttributeError, NotImplementedError):
@@ -2850,15 +2846,14 @@
             kv_cache_raw_tensors: The KV cache buffer of each layer.
         """

-        for i, kv_cache_group_spec in enumerate(
-                kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            for layer_name in kv_cache_group_spec.layer_names:
+        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+            for layer_name in group.layer_names:
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 num_blocks = (raw_tensor.numel() //
                               kv_cache_spec.page_size_bytes)
                 if isinstance(kv_cache_spec, AttentionSpec):
-                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
+                    kv_cache_shape = group.backend.get_kv_cache_shape(
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     if kv_cache_shape[0] != num_blocks or kv_cache_shape[
@@ -2893,6 +2888,7 @@
                 self.shared_kv_cache_layers,
                 kv_cache_config.kv_cache_groups,
                 kv_caches,
+                self.attn_groups,
             )
         attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention)
@@ -2958,9 +2954,9 @@
                 continue

             # TODO: Support other attention modules, e.g., cross-attention
+            # TODO(lucas): move the attention specs into the model layers like
+            # the attention backends
             if attn_module.attn_type == AttentionType.DECODER:
-                use_local_attention = (self.attention_chunk_size is not None
-                                       and attn_module.use_irope)
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2969,10 +2965,8 @@
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
-                    assert not use_local_attention, (
-                        "attention module can not be with ",
-                        "both local attention and sliding window")
-                elif use_local_attention:
+                elif self.attention_chunk_size is not None \
+                    and isinstance(attn_module, ChunkedLocalAttention):
                     kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
@@ -3043,7 +3037,7 @@

         # Use the first attention metadata builder
         # to create encoder attention metadata
-        builder = self.attn_metadata_builders[0]
+        builder = self.attn_groups[0][0].metadata_builder

         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
@@ -15,8 +15,9 @@ import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr

 import vllm.envs as envs
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import (ParallelConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)

@@ -518,7 +519,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 continue

             if attn_module.attn_type == AttentionType.DECODER:
-                if attn_module.use_irope:
+                if isinstance(attn_module, ChunkedLocalAttention):
                     logger.warning_once(
                         "Using irope in Pallas is not supported yet, it "
                         "will fall back to global attention for long context.")
@@ -1,14 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

 import torch

+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import ModelConfig, SchedulerConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.registry import MultiModalRegistry
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec

@@ -122,6 +125,13 @@ class MultiModalBudget:
         return max_items_per_prompt, max_items_per_batch


+@dataclass
+class AttentionGroup:
+    backend: type[AttentionBackend]
+    metadata_builder: AttentionMetadataBuilder
+    layer_names: list[str]
+
+
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
@@ -196,6 +206,8 @@ def initialize_kv_cache_for_kv_sharing(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
     kv_caches: dict[str, torch.Tensor],
+    # Optional for now to avoid breaking TPU
+    attn_groups: Optional[list[list[AttentionGroup]]] = None,
 ) -> None:
     """
     Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`

@@ -225,6 +237,15 @@
             group_idx = layer_to_kv_cache_group_idx[target_layer_name]
             kv_cache_groups[group_idx].layer_names.append(layer_name)

+            if attn_groups is not None:
+                assert len(attn_groups[group_idx]) == 1, (
+                    "Only one attention group per KV cache group is supported "
+                    "for KV-cache sharing for now.")
+                # TODO(lucas): I think in the future the layers that re-use a
+                # KV cache will be in a different attention group so we can
+                # remove this code from here.
+                attn_groups[group_idx][0].layer_names.append(layer_name)


 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],