From 17373dcd93ca60554d72cef4e159e70abbfd15af Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 21 Aug 2025 22:05:59 -0700 Subject: [PATCH] [Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154) Signed-off-by: Chen Zhang --- tests/v1/worker/test_gpu_model_runner.py | 11 +- .../layers/chunked_local_attention.py | 29 +-- .../layers/encoder_only_attention.py | 86 +++++++ vllm/model_executor/models/bert.py | 17 +- vllm/model_executor/models/bert_with_rope.py | 17 +- vllm/model_executor/models/llama.py | 6 +- vllm/model_executor/models/modernbert.py | 14 +- vllm/model_executor/models/qwen2.py | 5 +- vllm/v1/attention/backends/utils.py | 32 +-- vllm/v1/kv_cache_interface.py | 8 + vllm/v1/worker/gpu_model_runner.py | 211 ++++++------------ vllm/v1/worker/utils.py | 4 + 12 files changed, 226 insertions(+), 214 deletions(-) create mode 100644 vllm/attention/layers/encoder_only_attention.py diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 4bcc63f293e03..b9b2314ce573f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) + kv_cache_config_after_init = runner.kv_cache_config layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0] @@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert id(layer_1_kv) == id(layer_0_kv) # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + assert len(kv_cache_config_after_init.kv_cache_groups) == 1 + assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 892077ba91e07..087c5004bde06 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -6,12 +6,13 @@ from typing import List, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, QuantizationConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend, subclass_attention_metadata_builder) + subclass_attention_backend) from ..layer import Attention @@ -24,21 +25,23 @@ def create_chunked_local_attention_backend( ) -> type[AttentionBackend]: prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" - def build_preprocess_fn(cm: CommonAttentionMetadata): - return make_local_attention_virtual_batches(attention_chunk_size, cm, - block_size) + underlying_builder = underlying_attn_backend.get_builder_cls() + + class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + 
common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + common_attn_metadata = make_local_attention_virtual_batches( + attention_chunk_size, common_attn_metadata, block_size) + return super().build(common_prefix_len, common_attn_metadata, + fast_build) - # Dynamically create a new attention backend that wraps the - # underlying attention backend but applies - # `make_local_attention_virtual_batches` before calling `build(...)` - builder_cls = subclass_attention_metadata_builder( - name_prefix=prefix, - builder_cls=underlying_attn_backend.get_builder_cls(), - build_preprocess_fn=build_preprocess_fn) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=builder_cls) + builder_cls=ChunkedLocalAttentionBuilder) return attn_backend diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py new file mode 100644 index 0000000000000..7b3dcbd823c06 --- /dev/null +++ b/vllm/attention/layers/encoder_only_attention.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import torch +from transformers import CacheConfig + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) + + +@functools.lru_cache +def create_encoder_only_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "EncoderOnlyAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata = copy(common_attn_metadata) + new_common_attn_metadata.causal = False + return super().build(common_prefix_len, new_common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=EncoderOnlyAttentionBuilder) + + return attn_backend + + +class EncoderOnlyAttention(Attention): + """ + Encoder attention is a special case that doesn't need a KV Cache. 
+ """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_encoder_only_attention_backend( + underlying_attn_backend) + else: + # in v0 encoder only attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_ONLY, \ + "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2bd5eb5bb7aa8..22b6c4401213c 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index e18b7b7ffabab..129450927e564 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module): self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") self.out_proj = 
RowParallelLinear(input_size=hidden_size, output_size=hidden_size, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 24cd448d8361f..f99f1c3643fd4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,6 +31,7 @@ from torch import nn from transformers import LlamaConfig from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -173,7 +174,10 @@ class LlamaAttention(nn.Module): if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index c6e84e2d4e040..72290bf2ee29f 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import ModernBertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module): head_size=self.head_dim, dim=self.head_dim, base=rope_theta) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - prefix=f"{layer_id}.attn", - attn_type=AttentionType.ENCODER_ONLY, - per_layer_sliding_window=sliding_window) + self.attn = EncoderOnlyAttention( + self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + per_layer_sliding_window=sliding_window) self.Wo = RowParallelLinear(config.hidden_size, config.hidden_size, bias=config.attention_bias) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7304fbf120ccd..b6a1d2db303c7 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -32,6 +32,7 @@ from torch import nn from transformers import Qwen2Config from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 57c4d436c5b6b..39bdbe125635b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,7 @@ import enum import functools from abc import abstractmethod from dataclasses import dataclass, make_dataclass -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, - TypeVar) +from typing import 
TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar import numpy as np import torch @@ -543,35 +542,6 @@ def make_local_attention_virtual_batches( ) -def subclass_attention_metadata_builder( - name_prefix: str, - builder_cls: type[AttentionMetadataBuilder[M]], - build_preprocess_fn: Callable[[CommonAttentionMetadata], - CommonAttentionMetadata], -) -> type[AttentionMetadataBuilder[M]]: - """ - Return a new subclass of `builder_cls` whose .build(...) method - first calls build_preprocess_fn(common_attn_metadata) on the metadata. - """ - name: str = name_prefix + builder_cls.__name__ # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False): - return builder_cls.build(self, common_prefix_len, - build_preprocess_fn(common_attn_metadata), - fast_build) - - Wrapped = type( - name, - (builder_cls, ), # inherit from the original - { - "build": build, - }) - return Wrapped # type: ignore - - def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 429416afa2483..ed8e0bf798988 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -203,6 +203,14 @@ class MambaSpec(KVCacheSpec): return self.page_size_bytes +@dataclass(frozen=True) +class EncoderOnlyAttentionSpec(AttentionSpec): + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # Encoder-only layers do not need KV cache + return 0 + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 870aca41ec2ab..d520b71de3ff9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,6 +8,7 @@ import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import ( from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata @@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = model_config.pooler_config is not None - self.is_encoder_only_model = False self.is_multimodal_raw_input_supported = ( model_config.is_multimodal_raw_input_supported) self.max_model_len = model_config.max_model_len @@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + # Cached outputs. 
self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None @@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata: dict[str, Any] = {} - # Prepare encoder attention metadata separately - # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - - per_layer_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) - - # Add encoder attention metadata for all encoder layers - attention_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - common_attn_metadata, encoder_attn_metadata =\ - per_layer_metadata[layer_name] - attn_metadata[layer_name] = encoder_attn_metadata - # Used in the below loop. query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] @@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, non_blocking=True) + slot_mapping = torch.zeros((total_num_scheduled_tokens, ), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, + non_blocking=True) + num_common_prefix_blocks = 0 + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[: + total_num_scheduled_tokens] - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = ( + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id]) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. 
- num_common_prefix_blocks[kv_cache_group_id], + num_common_prefix_blocks, kv_cache_group_spec.kv_cache_spec, builder, ) @@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Calculate reorder batch threshold (if neeeded) self.calculate_reorder_batch_threshold() - if len(self.attn_groups) > 0: - return - - # Check if model is encoder-only - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) - for layer_name, attn_module in attn_layers.items(): - - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - if attn_module.sliding_window is None: - attn_spec: AttentionSpec = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - else: - attn_spec = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) - attn_specs[attn_spec].append(layer_name) - - else: - raise ValueError("Expected only encoder-only layers") - - if len(attn_specs) > 0: - total_layers = 0 - for attn_spec, layer_names in attn_specs.items(): - - attn_backends = get_attn_backends_for_layers(layer_names) - total_layers += len(layer_names) - - self.attn_groups.append( - create_attn_groups(attn_backends, attn_spec)) - assert total_layers == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" - self.is_encoder_only_model = True - def initialize_cudagraph_capture(self) -> None: min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None @@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): layer_names = set() for group in kv_cache_config.kv_cache_groups: - layer_names.update(group.layer_names) + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // @@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config.kv_cache_groups, kv_caches, self.attn_groups, + self.runner_only_attn_layers, ) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) @@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if 
has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. + """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mamba_type=mamba_module.mamba_type) return kv_cache_spec - - def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - dict[str, tuple[CommonAttentionMetadata, Any]]: - """Prepare encoder attention metadata for encoder-only models. - - Args: - scheduler_output: Scheduler output - - Returns: - dict[str, Any]: Encoder attention metadata - """ - num_reqs = self.input_batch.num_reqs - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - max_num_scheduled_tokens = max(tokens) - - dummy_block_table = torch.zeros((num_reqs, 1), - dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, - non_blocking=True) - dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, - non_blocking=True) - - group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]() - - for attn_group_list in self.attn_groups: - - assert len(attn_group_list) == 1 - attn_group = attn_group_list[0] - - # Use the first attention metadata builder - # to create encoder attention metadata - builder = attn_group.metadata_builder - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. 
- num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(), - block_table_tensor=dummy_block_table, - slot_mapping=dummy_slot_mapping, - causal=False, - ) - - metadata = builder.build( - common_prefix_len=0, # No cascade for encoder - common_attn_metadata=common_metadata, - ) - - for layer_name in attn_group.layer_names: - group_metadata[layer_name] = (common_metadata, metadata) - - return group_metadata diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index c7ccd2e254976..ffc1a11bc3ba1 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -204,6 +204,7 @@ def initialize_kv_cache_for_kv_sharing( kv_caches: dict[str, torch.Tensor], # Optional for now to avoid breaking TPU attn_groups: Optional[list[list[AttentionGroup]]] = None, + runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -250,6 +251,9 @@ def initialize_kv_cache_for_kv_sharing( attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( layer_name) + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) + def bind_kv_cache( kv_caches: dict[str, torch.Tensor],
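
The core refactor in this patch replaces the removed `subclass_attention_metadata_builder` helper with explicit builder subclasses: `ChunkedLocalAttentionBuilder` and `EncoderOnlyAttentionBuilder` each override `build(...)`, preprocess the `CommonAttentionMetadata` (local-attention virtual batching in the former, forcing `causal=False` in the latter), and then delegate to the underlying backend's builder. The sketch below illustrates that pattern in isolation; `Metadata`, `BaseBuilder`, and `make_encoder_only_builder` are simplified stand-ins for illustration only, not vLLM APIs.

# Minimal, self-contained sketch of the builder-subclassing pattern used in
# this patch. Metadata and BaseBuilder are hypothetical stand-ins; the real
# code wraps CommonAttentionMetadata and the backend's AttentionMetadataBuilder.
import functools
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Metadata:
    causal: bool = True


class BaseBuilder:

    def build(self,
              common_prefix_len: int,
              metadata: Metadata,
              fast_build: bool = False) -> Metadata:
        # The real builder turns CommonAttentionMetadata into
        # backend-specific attention metadata; here we just echo it back.
        return metadata


@functools.lru_cache
def make_encoder_only_builder(base: type[BaseBuilder]) -> type[BaseBuilder]:
    """Return a subclass of `base` whose build() always sees causal=False."""

    class EncoderOnlyBuilder(base):  # type: ignore[misc, valid-type]

        def build(self,
                  common_prefix_len: int,
                  metadata: Metadata,
                  fast_build: bool = False) -> Metadata:
            # Preprocess the metadata, then delegate to the parent builder,
            # mirroring EncoderOnlyAttentionBuilder in this patch.
            return super().build(common_prefix_len,
                                 replace(metadata, causal=False), fast_build)

    return EncoderOnlyBuilder


if __name__ == "__main__":
    builder_cls = make_encoder_only_builder(BaseBuilder)
    print(builder_cls().build(0, Metadata()).causal)  # -> False

The `functools.lru_cache` on the factory mirrors `create_encoder_only_attention_backend`: repeated calls with the same underlying backend reuse one dynamically created subclass instead of minting a new class per layer. On the KV-cache side, `EncoderOnlyAttentionSpec.max_memory_usage_bytes` returns 0, so encoder-only layers join a runner-only KV cache group (tracked in `runner_only_attn_layers`) that is given a dummy block table and slot mapping but never reserves cache memory.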