[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang authored on 2025-08-21 22:05:59 -07:00 (committed by GitHub)
parent 5964069367
commit 17373dcd93
12 changed files with 226 additions and 214 deletions

View File

@@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
kv_cache_spec[layer_0].page_size_bytes
runner.initialize_kv_cache(kv_cache_config)
kv_cache_config_after_init = runner.kv_cache_config
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
@@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
assert len(kv_cache_config_after_init.kv_cache_groups) == 1
assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
0] == layer_0
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):

View File

@@ -6,12 +6,13 @@ from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
subclass_attention_backend)
from ..layer import Attention
@@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size)
return super().build(common_prefix_len, common_attn_metadata,
fast_build)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
builder_cls=ChunkedLocalAttentionBuilder)
return attn_backend
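Both wrapped builders in this change follow the same shape: subclass the underlying metadata builder and rewrite CommonAttentionMetadata before delegating to the parent build(). A minimal generic sketch of that pattern (wrap_builder and preprocess are illustrative names, not part of this diff):

def wrap_builder(underlying_builder, preprocess):
    # Illustrative only: derive a builder that preprocesses the shared
    # metadata and then defers to the real implementation.
    class WrappedBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata,
                  fast_build: bool = False):
            return super().build(common_prefix_len,
                                 preprocess(common_attn_metadata), fast_build)

    return WrappedBuilder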

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import torch
from vllm.config import CacheConfig
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(common_prefix_len, new_common_attn_metadata,
fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_encoder_only_attention_backend(
underlying_attn_backend)
else:
# in v0 encoder only attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, \
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs)
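A minimal usage sketch, mirroring the model updates in the files below (the keyword arguments are the ones the existing Attention constructor already accepts; attn_type no longer needs to be passed):

# Inside a hypothetical encoder layer's __init__:
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
                                 head_size=self.head_dim,
                                 scale=self.scaling,
                                 num_kv_heads=self.num_kv_heads,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.attn")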

View File

@@ -8,7 +8,7 @@ import torch
from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@@ -7,7 +7,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module):
self.rotary_emb = get_rope(**rotary_kwargs)
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.out_proj = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size,

View File

@@ -31,6 +31,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
if is_sliding:
sliding_window = config.sliding_window
self.attn = Attention(
attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,

View File

@@ -7,7 +7,7 @@ import torch
from torch import nn
from transformers import ModernBertConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module):
head_size=self.head_dim,
dim=self.head_dim,
base=rope_theta)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
attn_type=AttentionType.ENCODER_ONLY,
per_layer_sliding_window=sliding_window)
self.attn = EncoderOnlyAttention(
self.num_heads,
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
per_layer_sliding_window=sliding_window)
self.Wo = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=config.attention_bias)

View File

@@ -32,6 +32,7 @@ from torch import nn
from transformers import Qwen2Config
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,

View File

@@ -5,8 +5,7 @@ import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
TypeVar)
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
import numpy as np
import torch
@@ -543,35 +542,6 @@ def make_local_attention_virtual_batches(
)
def subclass_attention_metadata_builder(
name_prefix: str,
builder_cls: type[AttentionMetadataBuilder[M]],
build_preprocess_fn: Callable[[CommonAttentionMetadata],
CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
"""
Return a new subclass of `builder_cls` whose .build(...) method
first calls build_preprocess_fn(common_attn_metadata) on the metadata.
"""
name: str = name_prefix + builder_cls.__name__ # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
return builder_cls.build(self, common_prefix_len,
build_preprocess_fn(common_attn_metadata),
fast_build)
Wrapped = type(
name,
(builder_cls, ), # inherit from the original
{
"build": build,
})
return Wrapped # type: ignore
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]
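With the generic wrapper removed, callers now hand an explicit builder subclass to subclass_attention_backend. A rough sketch of how such a helper can be written with type() (a plausible shape under that assumption, not necessarily the exact implementation kept in this file):

def subclass_attention_backend_sketch(name_prefix, attention_backend_cls,
                                      builder_cls):
    # Derive a backend class whose get_builder_cls() returns the supplied
    # builder subclass; everything else is inherited unchanged.
    name = name_prefix + attention_backend_cls.__name__
    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": staticmethod(lambda: builder_cls)})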

View File

@@ -203,6 +203,14 @@ class MambaSpec(KVCacheSpec):
return self.page_size_bytes
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# Encoder-only layers do not need KV cache
return 0
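Returning 0 here is what keeps encoder-only layers from consuming any of the KV cache budget. For contrast, a rough back-of-the-envelope for a regular full-attention layer (illustrative numbers; the real accounting lives in the other spec classes):

# Illustrative only: per-block KV bytes for a regular attention layer
# (key plane + value plane), versus zero for an encoder-only layer.
block_size, num_kv_heads, head_size, dtype_bytes = 16, 8, 128, 2  # fp16
full_attn_block_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes
print(full_attn_block_bytes)  # 65536 bytes per block
encoder_only_bytes = 0        # EncoderOnlyAttentionSpec reports no usage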
@dataclass
class KVCacheTensor:
"""

View File

@@ -8,6 +8,7 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import numpy as np
@@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
@@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.is_multimodal_raw_input_supported = (
model_config.is_multimodal_raw_input_supported)
self.max_model_len = model_config.max_model_len
@@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.reorder_batch_threshold: Optional[int] = None
# Attention layers that are only in the KVCacheConfig of the runner
# (e.g., KV sharing, encoder-only attention), but not in the
# KVCacheConfig of the scheduler.
self.runner_only_attn_layers: set[str] = set()
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
@@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: dict[str, Any] = {}
# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
if self.is_encoder_only_model:
per_layer_metadata = \
self._build_encoder_only_attn_metadata(
scheduler_output)
# Add encoder attention metadata for all encoder layers
attention_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attention_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
common_attn_metadata, encoder_attn_metadata =\
per_layer_metadata[layer_name]
attn_metadata[layer_name] = encoder_attn_metadata
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
@@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
(num_reqs, 1),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device, non_blocking=True)
slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:
total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
@@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
num_common_prefix_blocks,
kv_cache_group_spec.kv_cache_spec,
builder,
)
@@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
if len(self.attn_groups) > 0:
return
# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
if attn_module.sliding_window is None:
attn_spec: AttentionSpec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
else:
attn_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
attn_specs[attn_spec].append(layer_name)
else:
raise ValueError("Expected only encoder-only layers")
if len(attn_specs) > 0:
total_layers = 0
for attn_spec, layer_names in attn_specs.items():
attn_backends = get_attn_backends_for_layers(layer_names)
total_layers += len(layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, attn_spec))
assert total_layers == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
self.is_encoder_only_model = True
def initialize_cudagraph_capture(self) -> None:
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
@@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
layer_names.update(group.layer_names)
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
@@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
@@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config.kv_cache_groups,
kv_caches,
self.attn_groups,
self.runner_only_attn_layers,
)
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
@@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
@@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(
encoder_only_attn_specs
) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = encoder_only_attn_specs.popitem()
self.kv_cache_config.kv_cache_groups.append(
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
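The net effect is that the runner-side copy of the config gains one extra group that the scheduler never sees. A hedged sketch of the resulting layout (layer names and spec values are hypothetical; field names follow the specs used elsewhere in this diff):

import torch
from vllm.v1.kv_cache_interface import (EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheGroupSpec)

groups = [
    # Scheduler-visible group backed by real KV cache tensors.
    KVCacheGroupSpec(layer_names=["model.layers.0.attn"],
                     kv_cache_spec=FullAttentionSpec(block_size=16,
                                                     num_kv_heads=8,
                                                     head_size=128,
                                                     dtype=torch.float16,
                                                     use_mla=False)),
    # Runner-only group appended by may_add_encoder_only_layers_to_kv_cache_config.
    KVCacheGroupSpec(layer_names=["model.encoder.0.attn"],
                     kv_cache_spec=EncoderOnlyAttentionSpec(block_size=16,
                                                            num_kv_heads=8,
                                                            head_size=128,
                                                            dtype=torch.float16,
                                                            use_mla=False)),
]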
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
@@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mamba_type=mamba_module.mamba_type)
return kv_cache_spec
def _build_encoder_only_attn_metadata(
self, scheduler_output: "SchedulerOutput") -> \
dict[str, tuple[CommonAttentionMetadata, Any]]:
"""Prepare encoder attention metadata for encoder-only models.
Args:
scheduler_output: Scheduler output
Returns:
dict[str, Any]: Encoder attention metadata
"""
num_reqs = self.input_batch.num_reqs
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
max_num_scheduled_tokens = max(tokens)
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
for attn_group_list in self.attn_groups:
assert len(attn_group_list) == 1
attn_group = attn_group_list[0]
# Use the first attention metadata builder
# to create encoder attention metadata
builder = attn_group.metadata_builder
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
block_table_tensor=dummy_block_table,
slot_mapping=dummy_slot_mapping,
causal=False,
)
metadata = builder.build(
common_prefix_len=0, # No cascade for encoder
common_attn_metadata=common_metadata,
)
for layer_name in attn_group.layer_names:
group_metadata[layer_name] = (common_metadata, metadata)
return group_metadata

View File

@@ -204,6 +204,7 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches: dict[str, torch.Tensor],
# Optional for now to avoid breaking TPU
attn_groups: Optional[list[list[AttentionGroup]]] = None,
runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -250,6 +251,9 @@ def initialize_kv_cache_for_kv_sharing(
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],