[Attention] Unify mamba and attention backend selection (#23171)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
Ayush Satyam 2025-08-25 14:39:36 +05:30 committed by GitHub
parent d0a4a3f645
commit 5c4b6e66fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 186 additions and 72 deletions

View File

@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
from types import SimpleNamespace
import pytest
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
MiniMaxText01LinearAttention)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
@pytest.mark.parametrize(
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
(
MambaMixer,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
time_step_rank=8,
use_conv_bias=True,
use_bias=False,
use_rms_norm=True,
),
Mamba1AttentionBackend,
"mamba1",
),
(
MambaMixer2,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
use_conv_bias=True,
use_bias=False,
n_groups=1,
num_heads=8,
head_dim=32,
),
Mamba2AttentionBackend,
"mamba2",
),
(
MiniMaxText01LinearAttention,
dict(
hidden_size=128,
hidden_inner_size=256,
num_heads=8,
head_dim=32,
max_position=2048,
block_size=64,
num_hidden_layer=12,
layer_idx=0,
linear_layer_idx=0,
),
LinearAttentionBackend,
"linear_attention",
),
(
ShortConv,
dict(
config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
dim=128,
layer_idx=0,
),
ShortConvAttentionBackend,
"short_conv",
),
])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
expected_backend, expected_mamba_type):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)
backend_class = layer.get_attn_backend()
assert backend_class is expected_backend
assert layer.mamba_type == expected_mamba_type
@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
expected_mamba_type):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert hasattr(layer_class, 'get_attn_backend'), (
f"{layer_class.__name__} should have get_attn_backend method")
assert hasattr(layer_class, 'mamba_type'), (
f"{layer_class.__name__} should have mamba_type property")

View File

@ -1,25 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
import pytest
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
argvalues=[("mamba2", Mamba2AttentionBackend)])
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
backend_class = get_mamba_attn_backend(mamba_type)
assert backend_class is expected_backend
def test_get_mamba_attn_backend_unsupported():
unsupported_types = ["mamba", ""]
for mamba_type in unsupported_types:
err_message = f"Mamba Attention type {mamba_type} is not supported yet."
with pytest.raises(NotImplementedError, match=err_message):
get_mamba_attn_backend(mamba_type)

View File

@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -54,7 +55,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS
class Attention(nn.Module):
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors

View File

@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass

View File

@ -1,12 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
class MambaBase(ABC):
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
@ -32,3 +38,8 @@ class MambaBase(ABC):
@abstractmethod
def mamba_type(self) -> str:
pass
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend
def mamba_mixer2(
hidden_states: torch.Tensor,

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "short_conv"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
return ShortConvAttentionBackend
def short_conv(
hidden_states: torch.Tensor,

View File

@ -4,7 +4,10 @@
import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import regex as re
import torch
@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,

View File

@ -1,22 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
if mamba_type == "mamba1":
return Mamba1AttentionBackend
if mamba_type == "mamba2":
return Mamba2AttentionBackend
if mamba_type == "linear_attention":
return LinearAttentionBackend
if mamba_type == "short_conv":
return ShortConvAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.")

View File

@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than using
@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# they are cached correctly, there will be different objects per
# layer.
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
attn_backend = layers[layer_name].get_attn_backend()
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif isinstance(kv_cache_spec, MambaSpec):
attn_backends = {
get_mamba_attn_backend(kv_cache_spec.mamba_type):
kv_cache_group_spec.layer_names
}
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))