[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
parent d0a4a3f645 → commit 5c4b6e66fe

tests/v1/attention/test_attention_backends_selection.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""

from types import SimpleNamespace

import pytest

from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
    MiniMaxText01LinearAttention)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
    ShortConvAttentionBackend)


@pytest.mark.parametrize(
    "layer_class, init_kwargs, expected_backend, expected_mamba_type", [
        (
            MambaMixer,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                time_step_rank=8,
                use_conv_bias=True,
                use_bias=False,
                use_rms_norm=True,
            ),
            Mamba1AttentionBackend,
            "mamba1",
        ),
        (
            MambaMixer2,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                use_conv_bias=True,
                use_bias=False,
                n_groups=1,
                num_heads=8,
                head_dim=32,
            ),
            Mamba2AttentionBackend,
            "mamba2",
        ),
        (
            MiniMaxText01LinearAttention,
            dict(
                hidden_size=128,
                hidden_inner_size=256,
                num_heads=8,
                head_dim=32,
                max_position=2048,
                block_size=64,
                num_hidden_layer=12,
                layer_idx=0,
                linear_layer_idx=0,
            ),
            LinearAttentionBackend,
            "linear_attention",
        ),
        (
            ShortConv,
            dict(
                config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
                dim=128,
                layer_idx=0,
            ),
            ShortConvAttentionBackend,
            "short_conv",
        ),
    ])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
                                       expected_backend, expected_mamba_type):
    """Test that Mamba-like layers return the correct attention backend."""
    layer = layer_class(**init_kwargs)

    backend_class = layer.get_attn_backend()
    assert backend_class is expected_backend
    assert layer.mamba_type == expected_mamba_type


@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
    (MambaMixer, Mamba1AttentionBackend, "mamba1"),
    (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
    (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
    (ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
                                             expected_mamba_type):
    """Test that all Mamba layers have the unified get_attn_backend
    interface."""
    assert hasattr(layer_class, 'get_attn_backend'), (
        f"{layer_class.__name__} should have get_attn_backend method")
    assert hasattr(layer_class, 'mamba_type'), (
        f"{layer_class.__name__} should have mamba_type property")

(deleted test file for the old mamba backend selector, 25 lines removed)
@@ -1,25 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""

import pytest

from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend


@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
                         argvalues=[("mamba2", Mamba2AttentionBackend)])
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
    backend_class = get_mamba_attn_backend(mamba_type)

    assert backend_class is expected_backend


def test_get_mamba_attn_backend_unsupported():
    unsupported_types = ["mamba", ""]

    for mamba_type in unsupported_types:
        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
        with pytest.raises(NotImplementedError, match=err_message):
            get_mamba_attn_backend(mamba_type)

@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -54,7 +55,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS


-class Attention(nn.Module):
+class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

     This class takes query, key, and value tensors as input. The input tensors
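
With Attention now deriving from AttentionLayerBase (defined in the new file below), dense attention and Mamba-style layers share one base class and can be discovered uniformly. A quick check, offered as an illustrative sketch rather than part of this diff:

# Illustrative sketch (not part of the commit): both layer families share the
# AttentionLayerBase interface after this change.
from vllm.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2

assert issubclass(Attention, AttentionLayerBase)
assert issubclass(MambaMixer2, AttentionLayerBase)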

vllm/model_executor/layers/attention_layer_base.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class AttentionLayerBase(ABC):
    """
    Base class for attention-like layers (Attention, Mamba, etc.)
    that support the v1 engine.

    This provides a common interface for getting attention backends
    from different layer types.
    """

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this layer."""
        pass
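
get_attn_backend is the only abstract member, so any layer-like class that implements it satisfies the interface. A hypothetical example (the class name and backend choice are illustrative, not part of the diff):

# Hypothetical example: a minimal layer that plugs into the unified
# backend-selection path by implementing the single abstract hook.
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend


class MyCustomLayer(AttentionLayerBase):
    """Toy layer used only to illustrate the interface."""

    def get_attn_backend(self):
        return Mamba2AttentionBackend


assert MyCustomLayer().get_attn_backend() is Mamba2AttentionBackend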

@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from collections.abc import Iterable
+from typing import TYPE_CHECKING

 import torch

+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+

-class MambaBase(ABC):
+class MambaBase(AttentionLayerBase):
     """
     Base class for Mamba-like layers which support the v1 engine.
     Inherit from this class if you implement a custom layer.
@@ -32,3 +38,8 @@ class MambaBase(ABC):
     @abstractmethod
     def mamba_type(self) -> str:
         pass
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        pass
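
For a custom Mamba-like layer, the contract now covers both mamba_type and get_attn_backend. A hedged sketch follows; the class is hypothetical, the backend choice is arbitrary, and any other abstract members of MambaBase are elided, so it is shown for shape only rather than instantiated:

# Hypothetical sketch of a custom Mamba-like layer; remaining abstract members
# of MambaBase are intentionally omitted, so this class is illustrative only.
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend


class MyMambaVariant(MambaBase):

    @property
    def mamba_type(self) -> str:
        return "my_mamba_variant"

    def get_attn_backend(self):
        # The in-tree layers defer the backend import into this method.
        return Mamba2AttentionBackend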

@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba1"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba1_attn import (
+            Mamba1AttentionBackend)
+        return Mamba1AttentionBackend
+
     def _time_proj_bias(self) -> Optional[torch.Tensor]:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()

@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba2_attn import (
+            Mamba2AttentionBackend)
+        return Mamba2AttentionBackend
+

 def mamba_mixer2(
     hidden_states: torch.Tensor,

@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import torch

@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "short_conv"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.short_conv_attn import (
+            ShortConvAttentionBackend)
+        return ShortConvAttentionBackend
+

 def short_conv(
     hidden_states: torch.Tensor,

@@ -4,7 +4,10 @@
 import copy
 import math
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 import regex as re
 import torch
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"

+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.linear_attn import (
+            LinearAttentionBackend)
+        return LinearAttentionBackend
+
     def get_state_dtype(self) -> tuple[torch.dtype]:
         return MambaStateDtypeCalculator.linear_attention_state_dtype(
             self.model_config.dtype,

(deleted file: the vllm.v1.attention.backends.mamba_selectors module, 22 lines removed)
@@ -1,22 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
    ShortConvAttentionBackend)


def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
    if mamba_type == "mamba1":
        return Mamba1AttentionBackend
    if mamba_type == "mamba2":
        return Mamba2AttentionBackend
    if mamba_type == "linear_attention":
        return LinearAttentionBackend
    if mamba_type == "short_conv":
        return ShortConvAttentionBackend

    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
                              "supported yet.")

@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than using
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # they are cached correctly, there will be different objects per
             # layer.
             for layer_name in layer_names:
-                attn_backend = attn_layers[layer_name].get_attn_backend()
+                attn_backend = layers[layer_name].get_attn_backend()
                 key = attn_backend.full_cls_name()
                 attn_backends[key] = attn_backend
                 attn_backend_layers[key].append(layer_name)
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            if isinstance(kv_cache_spec, AttentionSpec):
-                attn_backends = get_attn_backends_for_layers(
-                    kv_cache_group_spec.layer_names)
-            # TODO(lucas): move `get_mamba_attn_backend` into the mamba
-            # layers like above
-            elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backends = {
-                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
-                    kv_cache_group_spec.layer_names
-                }
-            else:
-                raise ValueError(
-                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
+            attn_backends = get_attn_backends_for_layers(
+                kv_cache_group_spec.layer_names)
             self.attn_groups.append(
                 create_attn_groups(attn_backends, kv_cache_spec))
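
get_attn_backends_for_layers above asks every layer for its backend and dedupes on the backend's full class name before building the attention groups. The same bookkeeping, reduced to a standalone sketch with hypothetical inputs (plain Python, no vLLM objects):

# Standalone sketch: group layer names by the backend class their layer
# reports, mirroring get_attn_backends_for_layers (which keys on
# backend.full_cls_name() rather than the module/qualname string used here).
from collections import defaultdict


def group_layers_by_backend(layers: dict) -> dict:
    """`layers` maps layer_name -> object exposing get_attn_backend()."""
    backends = {}
    backend_layers = defaultdict(list)
    for name, layer in layers.items():
        backend = layer.get_attn_backend()
        key = f"{backend.__module__}.{backend.__qualname__}"  # dedupe key
        backends[key] = backend
        backend_layers[key].append(name)
    return {backends[k]: names for k, names in backend_layers.items()}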