mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in: parent 5964069367, commit 17373dcd93
@@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
         kv_cache_spec[layer_0].page_size_bytes

     runner.initialize_kv_cache(kv_cache_config)
+    kv_cache_config_after_init = runner.kv_cache_config

     layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
     layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
@@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert id(layer_1_kv) == id(layer_0_kv)

     # check layer 1 added to kv cache group's layer names
-    assert len(kv_cache_config.kv_cache_groups) == 1
-    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
-    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
-    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+    assert len(kv_cache_config_after_init.kv_cache_groups) == 1
+    assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
+        0] == layer_0
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
+        1] == layer_1


 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
@@ -6,12 +6,13 @@ from typing import List, Optional
 import torch

 from vllm import envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig, QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata, make_local_attention_virtual_batches,
-    subclass_attention_backend, subclass_attention_metadata_builder)
+    subclass_attention_backend)

 from ..layer import Attention

@@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
 ) -> type[AttentionBackend]:
     prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

-    def build_preprocess_fn(cm: CommonAttentionMetadata):
-        return make_local_attention_virtual_batches(attention_chunk_size, cm,
-                                                    block_size)
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            common_attn_metadata = make_local_attention_virtual_batches(
+                attention_chunk_size, common_attn_metadata, block_size)
+            return super().build(common_prefix_len, common_attn_metadata,
+                                 fast_build)

     # Dynamically create a new attention backend that wraps the
     # underlying attention backend but applies
     # `make_local_attention_virtual_batches` before calling `build(...)`
-    builder_cls = subclass_attention_metadata_builder(
-        name_prefix=prefix,
-        builder_cls=underlying_attn_backend.get_builder_cls(),
-        build_preprocess_fn=build_preprocess_fn)
     attn_backend = subclass_attention_backend(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
-        builder_cls=builder_cls)
+        builder_cls=ChunkedLocalAttentionBuilder)

     return attn_backend
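Note: the hunk above replaces the callback-based `subclass_attention_metadata_builder` helper with an explicit builder subclass that preprocesses `CommonAttentionMetadata` before delegating to the parent `build()`. A self-contained toy sketch of that subclass-and-delegate pattern, using hypothetical names rather than vLLM classes:

# Toy illustration (hypothetical names, not vLLM code): wrap a builder so that
# every build() call first rewrites the metadata it receives.
from dataclasses import dataclass, replace


@dataclass
class ToyMetadata:
    num_tokens: int
    causal: bool = True


class ToyBuilder:

    def build(self, common_prefix_len: int, metadata: ToyMetadata) -> str:
        return f"built(prefix={common_prefix_len}, tokens={metadata.num_tokens})"


class PreprocessingBuilder(ToyBuilder):

    def build(self, common_prefix_len: int, metadata: ToyMetadata) -> str:
        # Preprocess the metadata, then delegate to the underlying builder.
        metadata = replace(metadata, num_tokens=metadata.num_tokens * 2)
        return super().build(common_prefix_len, metadata)


print(PreprocessingBuilder().build(0, ToyMetadata(num_tokens=8)))
# -> built(prefix=0, tokens=16)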
vllm/attention/layers/encoder_only_attention.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from copy import copy
+from typing import Optional
+
+import torch
+from transformers import CacheConfig
+
+from vllm import envs
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.layer import Attention
+from vllm.attention.selector import get_attn_backend
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              subclass_attention_backend)
+
+
+@functools.lru_cache
+def create_encoder_only_attention_backend(
+    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
+    prefix = "EncoderOnlyAttention_"
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata = copy(common_attn_metadata)
+            new_common_attn_metadata.causal = False
+            return super().build(common_prefix_len, new_common_attn_metadata,
+                                 fast_build)
+
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=EncoderOnlyAttentionBuilder)
+
+    return attn_backend
+
+
+class EncoderOnlyAttention(Attention):
+    """
+    Encoder attention is a special case that doesn't need a KV Cache.
+    """
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 cache_config: Optional[CacheConfig] = None,
+                 attn_type: Optional[str] = None,
+                 **kwargs):
+        dtype = torch.get_default_dtype()
+
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if envs.VLLM_USE_V1:
+            underlying_attn_backend = get_attn_backend(head_size, dtype,
+                                                       kv_cache_dtype,
+                                                       block_size)
+
+            attn_backend = create_encoder_only_attention_backend(
+                underlying_attn_backend)
+        else:
+            # in v0 encoder only attention is handled inside the backends
+            attn_backend = None
+
+        if attn_type is not None:
+            assert attn_type == AttentionType.ENCODER_ONLY, \
+                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
+
+        super().__init__(num_heads=num_heads,
+                         head_size=head_size,
+                         scale=scale,
+                         cache_config=cache_config,
+                         attn_backend=attn_backend,
+                         attn_type=AttentionType.ENCODER_ONLY,
+                         **kwargs)
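Design note: because `create_encoder_only_attention_backend` is decorated with `functools.lru_cache`, all encoder-only layers that resolve to the same underlying backend share one dynamically created subclass instead of minting a new class per layer. A minimal, self-contained sketch of that caching behaviour (toy names, not the vLLM factory):

# Toy sketch (hypothetical names): one cached wrapper class per base class.
import functools


@functools.lru_cache
def make_encoder_only_wrapper(base_cls: type) -> type:
    class Wrapper(base_cls):  # type: ignore[misc, valid-type]
        causal = False

    Wrapper.__name__ = f"EncoderOnly{base_cls.__name__}"
    return Wrapper


class FakeFlashAttentionBackend:
    causal = True


# Repeated calls with the same base class return the identical subclass.
assert make_encoder_only_wrapper(FakeFlashAttentionBackend) is \
    make_encoder_only_wrapper(FakeFlashAttentionBackend)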
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from transformers import BertConfig

-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj")

-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")

     def forward(
         self,
@@ -7,7 +7,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module):

         self.rotary_emb = get_rope(**rotary_kwargs)

-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")

         self.out_proj = RowParallelLinear(input_size=hidden_size,
                                           output_size=hidden_size,
@@ -31,6 +31,7 @@ from torch import nn
 from transformers import LlamaConfig

 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
         if is_sliding:
             sliding_window = config.sliding_window

-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
@@ -7,7 +7,7 @@ import torch
 from torch import nn
 from transformers import ModernBertConfig

-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module):
                                   head_size=self.head_dim,
                                   dim=self.head_dim,
                                   base=rope_theta)
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY,
-                              per_layer_sliding_window=sliding_window)
+        self.attn = EncoderOnlyAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            prefix=f"{layer_id}.attn",
+            per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -32,6 +32,7 @@ from torch import nn
 from transformers import Qwen2Config

 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
@@ -5,8 +5,7 @@ import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
-                    TypeVar)
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar

 import numpy as np
 import torch
@@ -543,35 +542,6 @@ def make_local_attention_virtual_batches(
     )


-def subclass_attention_metadata_builder(
-        name_prefix: str,
-        builder_cls: type[AttentionMetadataBuilder[M]],
-        build_preprocess_fn: Callable[[CommonAttentionMetadata],
-                                      CommonAttentionMetadata],
-) -> type[AttentionMetadataBuilder[M]]:
-    """
-    Return a new subclass of `builder_cls` whose .build(...) method
-    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
-    """
-    name: str = name_prefix + builder_cls.__name__  # type: ignore
-
-    def build(self,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False):
-        return builder_cls.build(self, common_prefix_len,
-                                 build_preprocess_fn(common_attn_metadata),
-                                 fast_build)
-
-    Wrapped = type(
-        name,
-        (builder_cls, ),  # inherit from the original
-        {
-            "build": build,
-        })
-    return Wrapped  # type: ignore
-
-
 def subclass_attention_backend(
         name_prefix: str, attention_backend_cls: type[AttentionBackend],
         builder_cls: type[AttentionMetadataBuilder[M]]
@@ -203,6 +203,14 @@ class MambaSpec(KVCacheSpec):
         return self.page_size_bytes


+@dataclass(frozen=True)
+class EncoderOnlyAttentionSpec(AttentionSpec):
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # Encoder-only layers do not need KV cache
+        return 0
+
+
 @dataclass
 class KVCacheTensor:
     """
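Because `EncoderOnlyAttentionSpec` reports zero memory, the runner can place encoder-only layers in a KV cache group without reserving blocks for them. A simplified, self-contained illustration of the idea (toy classes and a made-up size formula, not the vLLM definitions):

# Toy illustration of a zero-memory spec overriding a sized one.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyAttentionSpec:
    block_size: int
    num_kv_heads: int
    head_size: int

    def max_memory_usage_bytes(self) -> int:
        # Toy formula: K and V planes, 2 bytes per element (fp16).
        return 2 * 2 * self.block_size * self.num_kv_heads * self.head_size


@dataclass(frozen=True)
class ToyEncoderOnlySpec(ToyAttentionSpec):

    def max_memory_usage_bytes(self) -> int:
        # Encoder-only layers keep no KV cache at all.
        return 0


print(ToyAttentionSpec(16, 8, 64).max_memory_usage_bytes())    # 32768
print(ToyEncoderOnlySpec(16, 8, 64).max_memory_usage_bytes())  # 0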
@@ -8,6 +8,7 @@ import time
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional, Union, cast

 import numpy as np
@@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import (
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
+                                        EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, MambaSpec,
-                                        SlidingWindowSpec)
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             cache_config.cache_dtype]

         self.is_pooling_model = model_config.pooler_config is not None
-        self.is_encoder_only_model = False
         self.is_multimodal_raw_input_supported = (
             model_config.is_multimodal_raw_input_supported)
         self.max_model_len = model_config.max_model_len
@@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         self.reorder_batch_threshold: Optional[int] = None

+        # Attention layers that are only in the KVCacheConfig of the runner
+        # (e.g., KV sharing, encoder-only attention), but not in the
+        # KVCacheConfig of the scheduler.
+        self.runner_only_attn_layers: set[str] = set()
+
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
@@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         attn_metadata: dict[str, Any] = {}

-        # Prepare encoder attention metadata separately
-        # (encoder layers are not in KV cache groups)
-        if self.is_encoder_only_model:
-
-            per_layer_metadata = \
-                self._build_encoder_only_attn_metadata(
-                    scheduler_output)
-
-            # Add encoder attention metadata for all encoder layers
-            attention_layers = get_layers_from_vllm_config(
-                self.vllm_config, Attention)
-            for layer_name, attn_module in attention_layers.items():
-                if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                    common_attn_metadata, encoder_attn_metadata =\
-                        per_layer_metadata[layer_name]
-                    attn_metadata[layer_name] = encoder_attn_metadata
-
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
@@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):

-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
-            slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    pin_memory=self.pin_memory,
+                    device="cpu").to(self.device, non_blocking=True)
+                slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
+                                           dtype=torch.int32,
+                                           pin_memory=self.pin_memory,
+                                           device="cpu").to(self.device,
+                                                            non_blocking=True)
+                num_common_prefix_blocks = 0
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
+                slot_mapping = blk_table.slot_mapping[:
+                                                      total_num_scheduled_tokens]

-            # Fill unused with -1. Needed for reshape_and_cache in full cuda
-            # graph mode.
-            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+                # Fill unused with -1. Needed for reshape_and_cache in full
+                # cuda graph mode.
+                blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+                num_common_prefix_blocks = (
+                    scheduler_output.
+                    num_common_prefix_blocks[kv_cache_group_id])

             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
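The dummy block table and slot mapping in the branch above follow the usual pinned-host-then-asynchronous-copy idiom, so the placeholder tensors do not force a device synchronization. A standalone sketch of that idiom (the helper name and device string are illustrative, not vLLM API):

# Sketch of the pinned-memory idiom (illustrative helper, not vLLM code).
import torch


def int32_zeros_on_device(shape, device: str = "cuda") -> torch.Tensor:
    # Allocate in pinned (page-locked) host memory so the host-to-device copy
    # can overlap with other work, then issue a non-blocking copy.
    host = torch.zeros(shape, dtype=torch.int32, pin_memory=True, device="cpu")
    return host.to(device, non_blocking=True)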
@@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
-                    scheduler_output.
-                    num_common_prefix_blocks[kv_cache_group_id],
+                    num_common_prefix_blocks,
                     kv_cache_group_spec.kv_cache_spec,
                     builder,
                 )
@@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Calculate reorder batch threshold (if neeeded)
         self.calculate_reorder_batch_threshold()

-        if len(self.attn_groups) > 0:
-            return
-
-        # Check if model is encoder-only
-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
-        for layer_name, attn_module in attn_layers.items():
-
-            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                if attn_module.sliding_window is None:
-                    attn_spec: AttentionSpec = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        use_mla=use_mla)
-                else:
-                    attn_spec = SlidingWindowSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        sliding_window=attn_module.sliding_window,
-                        use_mla=use_mla)
-                attn_specs[attn_spec].append(layer_name)
-
-            else:
-                raise ValueError("Expected only encoder-only layers")
-
-        if len(attn_specs) > 0:
-            total_layers = 0
-            for attn_spec, layer_names in attn_specs.items():
-
-                attn_backends = get_attn_backends_for_layers(layer_names)
-                total_layers += len(layer_names)
-
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, attn_spec))
-            assert total_layers == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
-            self.is_encoder_only_model = True
-
     def initialize_cudagraph_capture(self) -> None:
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_builder_name = None
@@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
-            layer_names.update(group.layer_names)
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                layer_names.add(layer_name)
         assert layer_names == set(kv_cache_raw_tensors.keys(
         )), "Some layers are not correctly initialized"
         return kv_cache_raw_tensors
@@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             attn_backend = group.backend
             for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                 num_blocks = (raw_tensor.numel() //
@@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_config.kv_cache_groups,
                 kv_caches,
                 self.attn_groups,
+                self.runner_only_attn_layers,
             )
             attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                       Attention)
@@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
+        kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
         self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

@@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)

+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
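The grouping in `may_add_encoder_only_layers_to_kv_cache_config` relies on `EncoderOnlyAttentionSpec` being a frozen, hence hashable, dataclass, so identical specs collapse onto one dictionary key. A small standalone sketch of that pattern (toy names only):

# Toy sketch: layers with equal frozen-dataclass specs fall into one group.
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    head_size: int
    num_kv_heads: int


groups: dict[ToySpec, list[str]] = defaultdict(list)
for layer_name, spec in [("encoder.0.attn", ToySpec(64, 8)),
                         ("encoder.1.attn", ToySpec(64, 8))]:
    groups[spec].append(layer_name)

assert len(groups) == 1
assert groups[ToySpec(64, 8)] == ["encoder.0.attn", "encoder.1.attn"]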
@@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 mamba_type=mamba_module.mamba_type)

         return kv_cache_spec
-
-    def _build_encoder_only_attn_metadata(
-            self, scheduler_output: "SchedulerOutput") -> \
-            dict[str, tuple[CommonAttentionMetadata, Any]]:
-        """Prepare encoder attention metadata for encoder-only models.
-
-        Args:
-            scheduler_output: Scheduler output
-
-        Returns:
-            dict[str, Any]: Encoder attention metadata
-        """
-        num_reqs = self.input_batch.num_reqs
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-
-        # Get the number of scheduled tokens for each request.
-        req_ids = self.input_batch.req_ids
-        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-        max_num_scheduled_tokens = max(tokens)
-
-        dummy_block_table = torch.zeros((num_reqs, 1),
-                                        dtype=torch.int32,
-                                        pin_memory=self.pin_memory,
-                                        device="cpu").to(self.device,
-                                                         non_blocking=True)
-        dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
-                                         dtype=torch.int32,
-                                         pin_memory=self.pin_memory,
-                                         device="cpu").to(self.device,
-                                                          non_blocking=True)
-
-        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
-
-        for attn_group_list in self.attn_groups:
-
-            assert len(attn_group_list) == 1
-            attn_group = attn_group_list[0]
-
-            # Use the first attention metadata builder
-            # to create encoder attention metadata
-            builder = attn_group.metadata_builder
-
-            common_metadata = CommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-                seq_lens=self.seq_lens[:num_reqs],
-                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                num_computed_tokens_cpu=self.input_batch.
-                num_computed_tokens_cpu_tensor[:num_reqs],
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
-                block_table_tensor=dummy_block_table,
-                slot_mapping=dummy_slot_mapping,
-                causal=False,
-            )
-
-            metadata = builder.build(
-                common_prefix_len=0,  # No cascade for encoder
-                common_attn_metadata=common_metadata,
-            )
-
-            for layer_name in attn_group.layer_names:
-                group_metadata[layer_name] = (common_metadata, metadata)
-
-        return group_metadata
@@ -204,6 +204,7 @@ def initialize_kv_cache_for_kv_sharing(
     kv_caches: dict[str, torch.Tensor],
     # Optional for now to avoid breaking TPU
     attn_groups: Optional[list[list[AttentionGroup]]] = None,
+    runner_only_attn_layers: Optional[set[str]] = None,
 ) -> None:
     """
     Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -250,6 +251,9 @@ def initialize_kv_cache_for_kv_sharing(
         attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
             layer_name)

+        if runner_only_attn_layers is not None:
+            runner_only_attn_layers.add(layer_name)
+

 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],