mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 01:24:32 +08:00
[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
5964069367
commit
17373dcd93
@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
|||||||
kv_cache_spec[layer_0].page_size_bytes
|
kv_cache_spec[layer_0].page_size_bytes
|
||||||
|
|
||||||
runner.initialize_kv_cache(kv_cache_config)
|
runner.initialize_kv_cache(kv_cache_config)
|
||||||
|
kv_cache_config_after_init = runner.kv_cache_config
|
||||||
|
|
||||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||||
@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
|||||||
assert id(layer_1_kv) == id(layer_0_kv)
|
assert id(layer_1_kv) == id(layer_0_kv)
|
||||||
|
|
||||||
# check layer 1 added to kv cache group's layer names
|
# check layer 1 added to kv cache group's layer names
|
||||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
assert len(kv_cache_config_after_init.kv_cache_groups) == 1
|
||||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
0] == layer_0
|
||||||
|
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||||
|
1] == layer_1
|
||||||
|
|
||||||
|
|
||||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||||
|
|||||||
@ -6,12 +6,13 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata)
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig, QuantizationConfig
|
from vllm.config import CacheConfig, QuantizationConfig
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||||
subclass_attention_backend, subclass_attention_metadata_builder)
|
subclass_attention_backend)
|
||||||
|
|
||||||
from ..layer import Attention
|
from ..layer import Attention
|
||||||
|
|
||||||
@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
|
|||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||||
|
|
||||||
def build_preprocess_fn(cm: CommonAttentionMetadata):
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
return make_local_attention_virtual_batches(attention_chunk_size, cm,
|
|
||||||
block_size)
|
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||||
|
|
||||||
|
def build(self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False) -> AttentionMetadata:
|
||||||
|
common_attn_metadata = make_local_attention_virtual_batches(
|
||||||
|
attention_chunk_size, common_attn_metadata, block_size)
|
||||||
|
return super().build(common_prefix_len, common_attn_metadata,
|
||||||
|
fast_build)
|
||||||
|
|
||||||
# Dynamically create a new attention backend that wraps the
|
|
||||||
# underlying attention backend but applies
|
|
||||||
# `make_local_attention_virtual_batches` before calling `build(...)`
|
|
||||||
builder_cls = subclass_attention_metadata_builder(
|
|
||||||
name_prefix=prefix,
|
|
||||||
builder_cls=underlying_attn_backend.get_builder_cls(),
|
|
||||||
build_preprocess_fn=build_preprocess_fn)
|
|
||||||
attn_backend = subclass_attention_backend(
|
attn_backend = subclass_attention_backend(
|
||||||
name_prefix=prefix,
|
name_prefix=prefix,
|
||||||
attention_backend_cls=underlying_attn_backend,
|
attention_backend_cls=underlying_attn_backend,
|
||||||
builder_cls=builder_cls)
|
builder_cls=ChunkedLocalAttentionBuilder)
|
||||||
|
|
||||||
return attn_backend
|
return attn_backend
|
||||||
|
|
||||||
|
|||||||
86
vllm/attention/layers/encoder_only_attention.py
Normal file
86
vllm/attention/layers/encoder_only_attention.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import CacheConfig
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def create_encoder_only_attention_backend(
|
||||||
|
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
|
||||||
|
prefix = "EncoderOnlyAttention_"
|
||||||
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
|
||||||
|
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
|
||||||
|
|
||||||
|
def build(self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False) -> AttentionMetadata:
|
||||||
|
new_common_attn_metadata = copy(common_attn_metadata)
|
||||||
|
new_common_attn_metadata.causal = False
|
||||||
|
return super().build(common_prefix_len, new_common_attn_metadata,
|
||||||
|
fast_build)
|
||||||
|
|
||||||
|
attn_backend = subclass_attention_backend(
|
||||||
|
name_prefix=prefix,
|
||||||
|
attention_backend_cls=underlying_attn_backend,
|
||||||
|
builder_cls=EncoderOnlyAttentionBuilder)
|
||||||
|
|
||||||
|
return attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderOnlyAttention(Attention):
|
||||||
|
"""
|
||||||
|
Encoder attention is a special case that doesn't need a KV Cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
attn_type: Optional[str] = None,
|
||||||
|
**kwargs):
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
if cache_config is not None:
|
||||||
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
|
block_size = cache_config.block_size
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = "auto"
|
||||||
|
block_size = 16
|
||||||
|
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||||
|
kv_cache_dtype,
|
||||||
|
block_size)
|
||||||
|
|
||||||
|
attn_backend = create_encoder_only_attention_backend(
|
||||||
|
underlying_attn_backend)
|
||||||
|
else:
|
||||||
|
# in v0 encoder only attention is handled inside the backends
|
||||||
|
attn_backend = None
|
||||||
|
|
||||||
|
if attn_type is not None:
|
||||||
|
assert attn_type == AttentionType.ENCODER_ONLY, \
|
||||||
|
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
|
||||||
|
|
||||||
|
super().__init__(num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
scale=scale,
|
||||||
|
cache_config=cache_config,
|
||||||
|
attn_backend=attn_backend,
|
||||||
|
attn_type=AttentionType.ENCODER_ONLY,
|
||||||
|
**kwargs)
|
||||||
@ -8,7 +8,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import BertConfig
|
from transformers import BertConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.qkv_proj")
|
prefix=f"{prefix}.qkv_proj")
|
||||||
|
|
||||||
self.attn = Attention(num_heads=self.num_heads,
|
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
|
||||||
head_size=self.head_dim,
|
head_size=self.head_dim,
|
||||||
scale=self.scaling,
|
scale=self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn")
|
||||||
attn_type=AttentionType.ENCODER_ONLY)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||||
@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb = get_rope(**rotary_kwargs)
|
self.rotary_emb = get_rope(**rotary_kwargs)
|
||||||
|
|
||||||
self.attn = Attention(num_heads=self.num_heads,
|
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
|
||||||
head_size=self.head_dim,
|
head_size=self.head_dim,
|
||||||
scale=self.scaling,
|
scale=self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn")
|
||||||
attn_type=AttentionType.ENCODER_ONLY)
|
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(input_size=hidden_size,
|
self.out_proj = RowParallelLinear(input_size=hidden_size,
|
||||||
output_size=hidden_size,
|
output_size=hidden_size,
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention import Attention, AttentionType
|
||||||
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
|
|||||||
if is_sliding:
|
if is_sliding:
|
||||||
sliding_window = config.sliding_window
|
sliding_window = config.sliding_window
|
||||||
|
|
||||||
self.attn = Attention(
|
attn_cls = (EncoderOnlyAttention
|
||||||
|
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||||
|
|
||||||
|
self.attn = attn_cls(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ModernBertConfig
|
from transformers import ModernBertConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module):
|
|||||||
head_size=self.head_dim,
|
head_size=self.head_dim,
|
||||||
dim=self.head_dim,
|
dim=self.head_dim,
|
||||||
base=rope_theta)
|
base=rope_theta)
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = EncoderOnlyAttention(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
prefix=f"{layer_id}.attn",
|
self.scaling,
|
||||||
attn_type=AttentionType.ENCODER_ONLY,
|
prefix=f"{layer_id}.attn",
|
||||||
per_layer_sliding_window=sliding_window)
|
per_layer_sliding_window=sliding_window)
|
||||||
self.Wo = RowParallelLinear(config.hidden_size,
|
self.Wo = RowParallelLinear(config.hidden_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=config.attention_bias)
|
bias=config.attention_bias)
|
||||||
|
|||||||
@ -32,6 +32,7 @@ from torch import nn
|
|||||||
from transformers import Qwen2Config
|
from transformers import Qwen2Config
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention import Attention, AttentionType
|
||||||
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
|
|||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||||
)
|
)
|
||||||
self.attn = Attention(
|
attn_cls = (EncoderOnlyAttention
|
||||||
|
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||||
|
self.attn = attn_cls(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
|
|||||||
@ -5,8 +5,7 @@ import enum
|
|||||||
import functools
|
import functools
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, make_dataclass
|
from dataclasses import dataclass, make_dataclass
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
|
||||||
TypeVar)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -543,35 +542,6 @@ def make_local_attention_virtual_batches(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def subclass_attention_metadata_builder(
|
|
||||||
name_prefix: str,
|
|
||||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
|
||||||
build_preprocess_fn: Callable[[CommonAttentionMetadata],
|
|
||||||
CommonAttentionMetadata],
|
|
||||||
) -> type[AttentionMetadataBuilder[M]]:
|
|
||||||
"""
|
|
||||||
Return a new subclass of `builder_cls` whose .build(...) method
|
|
||||||
first calls build_preprocess_fn(common_attn_metadata) on the metadata.
|
|
||||||
"""
|
|
||||||
name: str = name_prefix + builder_cls.__name__ # type: ignore
|
|
||||||
|
|
||||||
def build(self,
|
|
||||||
common_prefix_len: int,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
fast_build: bool = False):
|
|
||||||
return builder_cls.build(self, common_prefix_len,
|
|
||||||
build_preprocess_fn(common_attn_metadata),
|
|
||||||
fast_build)
|
|
||||||
|
|
||||||
Wrapped = type(
|
|
||||||
name,
|
|
||||||
(builder_cls, ), # inherit from the original
|
|
||||||
{
|
|
||||||
"build": build,
|
|
||||||
})
|
|
||||||
return Wrapped # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def subclass_attention_backend(
|
def subclass_attention_backend(
|
||||||
name_prefix: str, attention_backend_cls: type[AttentionBackend],
|
name_prefix: str, attention_backend_cls: type[AttentionBackend],
|
||||||
builder_cls: type[AttentionMetadataBuilder[M]]
|
builder_cls: type[AttentionMetadataBuilder[M]]
|
||||||
|
|||||||
@ -203,6 +203,14 @@ class MambaSpec(KVCacheSpec):
|
|||||||
return self.page_size_bytes
|
return self.page_size_bytes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class EncoderOnlyAttentionSpec(AttentionSpec):
|
||||||
|
|
||||||
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
|
# Encoder-only layers do not need KV cache
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KVCacheTensor:
|
class KVCacheTensor:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
ChunkedLocalAttentionSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
|
EncoderOnlyAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec, MambaSpec,
|
KVCacheGroupSpec, KVCacheSpec,
|
||||||
SlidingWindowSpec)
|
MambaSpec, SlidingWindowSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
||||||
LogprobsTensors, ModelRunnerOutput)
|
LogprobsTensors, ModelRunnerOutput)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
cache_config.cache_dtype]
|
cache_config.cache_dtype]
|
||||||
|
|
||||||
self.is_pooling_model = model_config.pooler_config is not None
|
self.is_pooling_model = model_config.pooler_config is not None
|
||||||
self.is_encoder_only_model = False
|
|
||||||
self.is_multimodal_raw_input_supported = (
|
self.is_multimodal_raw_input_supported = (
|
||||||
model_config.is_multimodal_raw_input_supported)
|
model_config.is_multimodal_raw_input_supported)
|
||||||
self.max_model_len = model_config.max_model_len
|
self.max_model_len = model_config.max_model_len
|
||||||
@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
self.reorder_batch_threshold: Optional[int] = None
|
self.reorder_batch_threshold: Optional[int] = None
|
||||||
|
|
||||||
|
# Attention layers that are only in the KVCacheConfig of the runner
|
||||||
|
# (e.g., KV sharing, encoder-only attention), but not in the
|
||||||
|
# KVCacheConfig of the scheduler.
|
||||||
|
self.runner_only_attn_layers: set[str] = set()
|
||||||
|
|
||||||
# Cached outputs.
|
# Cached outputs.
|
||||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||||
torch.Tensor]] = None
|
torch.Tensor]] = None
|
||||||
@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
attn_metadata: dict[str, Any] = {}
|
attn_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
# Prepare encoder attention metadata separately
|
|
||||||
# (encoder layers are not in KV cache groups)
|
|
||||||
if self.is_encoder_only_model:
|
|
||||||
|
|
||||||
per_layer_metadata = \
|
|
||||||
self._build_encoder_only_attn_metadata(
|
|
||||||
scheduler_output)
|
|
||||||
|
|
||||||
# Add encoder attention metadata for all encoder layers
|
|
||||||
attention_layers = get_layers_from_vllm_config(
|
|
||||||
self.vllm_config, Attention)
|
|
||||||
for layer_name, attn_module in attention_layers.items():
|
|
||||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
common_attn_metadata, encoder_attn_metadata =\
|
|
||||||
per_layer_metadata[layer_name]
|
|
||||||
attn_metadata[layer_name] = encoder_attn_metadata
|
|
||||||
|
|
||||||
# Used in the below loop.
|
# Used in the below loop.
|
||||||
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
|
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
|
||||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||||
@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups):
|
self.kv_cache_config.kv_cache_groups):
|
||||||
|
|
||||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||||||
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
|
EncoderOnlyAttentionSpec):
|
||||||
slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
|
# Encoder-only layers do not have KV cache, so we need to
|
||||||
|
# create a dummy block table and slot mapping for them.
|
||||||
|
blk_table_tensor = torch.zeros(
|
||||||
|
(num_reqs, 1),
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
device="cpu").to(self.device, non_blocking=True)
|
||||||
|
slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
device="cpu").to(self.device,
|
||||||
|
non_blocking=True)
|
||||||
|
num_common_prefix_blocks = 0
|
||||||
|
else:
|
||||||
|
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||||
|
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
|
||||||
|
slot_mapping = blk_table.slot_mapping[:
|
||||||
|
total_num_scheduled_tokens]
|
||||||
|
|
||||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||||
# graph mode.
|
# graph mode.
|
||||||
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||||
|
num_common_prefix_blocks = (
|
||||||
|
scheduler_output.
|
||||||
|
num_common_prefix_blocks[kv_cache_group_id])
|
||||||
|
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
common_attn_metadata = CommonAttentionMetadata(
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if self.cascade_attn_enabled:
|
if self.cascade_attn_enabled:
|
||||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||||
num_scheduled_tokens,
|
num_scheduled_tokens,
|
||||||
scheduler_output.
|
num_common_prefix_blocks,
|
||||||
num_common_prefix_blocks[kv_cache_group_id],
|
|
||||||
kv_cache_group_spec.kv_cache_spec,
|
kv_cache_group_spec.kv_cache_spec,
|
||||||
builder,
|
builder,
|
||||||
)
|
)
|
||||||
@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Calculate reorder batch threshold (if neeeded)
|
# Calculate reorder batch threshold (if neeeded)
|
||||||
self.calculate_reorder_batch_threshold()
|
self.calculate_reorder_batch_threshold()
|
||||||
|
|
||||||
if len(self.attn_groups) > 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if model is encoder-only
|
|
||||||
block_size = self.vllm_config.cache_config.block_size
|
|
||||||
use_mla = self.vllm_config.model_config.use_mla
|
|
||||||
attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
|
|
||||||
for layer_name, attn_module in attn_layers.items():
|
|
||||||
|
|
||||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
if attn_module.sliding_window is None:
|
|
||||||
attn_spec: AttentionSpec = FullAttentionSpec(
|
|
||||||
block_size=block_size,
|
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
|
||||||
head_size=attn_module.head_size,
|
|
||||||
dtype=self.kv_cache_dtype,
|
|
||||||
use_mla=use_mla)
|
|
||||||
else:
|
|
||||||
attn_spec = SlidingWindowSpec(
|
|
||||||
block_size=block_size,
|
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
|
||||||
head_size=attn_module.head_size,
|
|
||||||
dtype=self.kv_cache_dtype,
|
|
||||||
sliding_window=attn_module.sliding_window,
|
|
||||||
use_mla=use_mla)
|
|
||||||
attn_specs[attn_spec].append(layer_name)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Expected only encoder-only layers")
|
|
||||||
|
|
||||||
if len(attn_specs) > 0:
|
|
||||||
total_layers = 0
|
|
||||||
for attn_spec, layer_names in attn_specs.items():
|
|
||||||
|
|
||||||
attn_backends = get_attn_backends_for_layers(layer_names)
|
|
||||||
total_layers += len(layer_names)
|
|
||||||
|
|
||||||
self.attn_groups.append(
|
|
||||||
create_attn_groups(attn_backends, attn_spec))
|
|
||||||
assert total_layers == len(attn_layers), \
|
|
||||||
"All or none of the layers are expected to be encoder-only"
|
|
||||||
self.is_encoder_only_model = True
|
|
||||||
|
|
||||||
def initialize_cudagraph_capture(self) -> None:
|
def initialize_cudagraph_capture(self) -> None:
|
||||||
min_cg_support = AttentionCGSupport.ALWAYS
|
min_cg_support = AttentionCGSupport.ALWAYS
|
||||||
min_cg_builder_name = None
|
min_cg_builder_name = None
|
||||||
@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
layer_names = set()
|
layer_names = set()
|
||||||
for group in kv_cache_config.kv_cache_groups:
|
for group in kv_cache_config.kv_cache_groups:
|
||||||
layer_names.update(group.layer_names)
|
for layer_name in group.layer_names:
|
||||||
|
if layer_name in self.runner_only_attn_layers:
|
||||||
|
continue
|
||||||
|
layer_names.add(layer_name)
|
||||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||||
)), "Some layers are not correctly initialized"
|
)), "Some layers are not correctly initialized"
|
||||||
return kv_cache_raw_tensors
|
return kv_cache_raw_tensors
|
||||||
@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||||
attn_backend = group.backend
|
attn_backend = group.backend
|
||||||
for layer_name in group.layer_names:
|
for layer_name in group.layer_names:
|
||||||
|
if layer_name in self.runner_only_attn_layers:
|
||||||
|
continue
|
||||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||||
num_blocks = (raw_tensor.numel() //
|
num_blocks = (raw_tensor.numel() //
|
||||||
@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_config.kv_cache_groups,
|
kv_cache_config.kv_cache_groups,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
self.attn_groups,
|
self.attn_groups,
|
||||||
|
self.runner_only_attn_layers,
|
||||||
)
|
)
|
||||||
attn_layers = get_layers_from_vllm_config(self.vllm_config,
|
attn_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||||
Attention)
|
Attention)
|
||||||
@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_config: Configuration for the KV cache, including the KV
|
kv_cache_config: Configuration for the KV cache, including the KV
|
||||||
cache size of each layer
|
cache size of each layer
|
||||||
"""
|
"""
|
||||||
|
kv_cache_config = deepcopy(kv_cache_config)
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.may_reinitialize_input_batch(kv_cache_config)
|
self.may_reinitialize_input_batch(kv_cache_config)
|
||||||
|
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||||
self.initialize_attn_backend(kv_cache_config)
|
self.initialize_attn_backend(kv_cache_config)
|
||||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||||
|
|
||||||
@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||||
|
|
||||||
|
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||||
|
"""
|
||||||
|
Add encoder-only layers to the KV cache config.
|
||||||
|
"""
|
||||||
|
block_size = self.vllm_config.cache_config.block_size
|
||||||
|
use_mla = self.vllm_config.model_config.use_mla
|
||||||
|
encoder_only_attn_specs: dict[AttentionSpec,
|
||||||
|
list[str]] = defaultdict(list)
|
||||||
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||||
|
for layer_name, attn_module in attn_layers.items():
|
||||||
|
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
attn_spec = EncoderOnlyAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
use_mla=use_mla)
|
||||||
|
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||||
|
self.runner_only_attn_layers.add(layer_name)
|
||||||
|
if len(encoder_only_attn_specs) > 0:
|
||||||
|
assert len(
|
||||||
|
encoder_only_attn_specs
|
||||||
|
) == 1, "Only support one encoder-only attention spec now"
|
||||||
|
spec, layer_names = encoder_only_attn_specs.popitem()
|
||||||
|
self.kv_cache_config.kv_cache_groups.append(
|
||||||
|
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
"""
|
"""
|
||||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||||
@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mamba_type=mamba_module.mamba_type)
|
mamba_type=mamba_module.mamba_type)
|
||||||
|
|
||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
def _build_encoder_only_attn_metadata(
|
|
||||||
self, scheduler_output: "SchedulerOutput") -> \
|
|
||||||
dict[str, tuple[CommonAttentionMetadata, Any]]:
|
|
||||||
"""Prepare encoder attention metadata for encoder-only models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scheduler_output: Scheduler output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Any]: Encoder attention metadata
|
|
||||||
"""
|
|
||||||
num_reqs = self.input_batch.num_reqs
|
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
||||||
|
|
||||||
# Get the number of scheduled tokens for each request.
|
|
||||||
req_ids = self.input_batch.req_ids
|
|
||||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
|
||||||
max_num_scheduled_tokens = max(tokens)
|
|
||||||
|
|
||||||
dummy_block_table = torch.zeros((num_reqs, 1),
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=self.pin_memory,
|
|
||||||
device="cpu").to(self.device,
|
|
||||||
non_blocking=True)
|
|
||||||
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=self.pin_memory,
|
|
||||||
device="cpu").to(self.device,
|
|
||||||
non_blocking=True)
|
|
||||||
|
|
||||||
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
|
|
||||||
|
|
||||||
for attn_group_list in self.attn_groups:
|
|
||||||
|
|
||||||
assert len(attn_group_list) == 1
|
|
||||||
attn_group = attn_group_list[0]
|
|
||||||
|
|
||||||
# Use the first attention metadata builder
|
|
||||||
# to create encoder attention metadata
|
|
||||||
builder = attn_group.metadata_builder
|
|
||||||
|
|
||||||
common_metadata = CommonAttentionMetadata(
|
|
||||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
|
||||||
seq_lens=self.seq_lens[:num_reqs],
|
|
||||||
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
|
|
||||||
num_computed_tokens_cpu=self.input_batch.
|
|
||||||
num_computed_tokens_cpu_tensor[:num_reqs],
|
|
||||||
num_reqs=num_reqs,
|
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
|
||||||
max_query_len=max_num_scheduled_tokens,
|
|
||||||
max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
|
|
||||||
block_table_tensor=dummy_block_table,
|
|
||||||
slot_mapping=dummy_slot_mapping,
|
|
||||||
causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = builder.build(
|
|
||||||
common_prefix_len=0, # No cascade for encoder
|
|
||||||
common_attn_metadata=common_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer_name in attn_group.layer_names:
|
|
||||||
group_metadata[layer_name] = (common_metadata, metadata)
|
|
||||||
|
|
||||||
return group_metadata
|
|
||||||
|
|||||||
@ -204,6 +204,7 @@ def initialize_kv_cache_for_kv_sharing(
|
|||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
# Optional for now to avoid breaking TPU
|
# Optional for now to avoid breaking TPU
|
||||||
attn_groups: Optional[list[list[AttentionGroup]]] = None,
|
attn_groups: Optional[list[list[AttentionGroup]]] = None,
|
||||||
|
runner_only_attn_layers: Optional[set[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
|
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
|
||||||
@ -250,6 +251,9 @@ def initialize_kv_cache_for_kv_sharing(
|
|||||||
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
|
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
|
||||||
layer_name)
|
layer_name)
|
||||||
|
|
||||||
|
if runner_only_attn_layers is not None:
|
||||||
|
runner_only_attn_layers.add(layer_name)
|
||||||
|
|
||||||
|
|
||||||
def bind_kv_cache(
|
def bind_kv_cache(
|
||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user