diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py
new file mode 100644
index 0000000000000..59e5628149468
--- /dev/null
+++ b/tests/v1/attention/test_attention_backends_selection.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mamba attention backend selectors."""
+
+from types import SimpleNamespace
+
+import pytest
+
+from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.short_conv import ShortConv
+from vllm.model_executor.models.minimax_text_01 import (
+    MiniMaxText01LinearAttention)
+from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
+from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.short_conv_attn import (
+    ShortConvAttentionBackend)
+
+
+@pytest.mark.parametrize(
+    "layer_class, init_kwargs, expected_backend, expected_mamba_type", [
+        (
+            MambaMixer,
+            dict(
+                hidden_size=128,
+                ssm_state_size=16,
+                conv_kernel_size=4,
+                intermediate_size=256,
+                time_step_rank=8,
+                use_conv_bias=True,
+                use_bias=False,
+                use_rms_norm=True,
+            ),
+            Mamba1AttentionBackend,
+            "mamba1",
+        ),
+        (
+            MambaMixer2,
+            dict(
+                hidden_size=128,
+                ssm_state_size=16,
+                conv_kernel_size=4,
+                intermediate_size=256,
+                use_conv_bias=True,
+                use_bias=False,
+                n_groups=1,
+                num_heads=8,
+                head_dim=32,
+            ),
+            Mamba2AttentionBackend,
+            "mamba2",
+        ),
+        (
+            MiniMaxText01LinearAttention,
+            dict(
+                hidden_size=128,
+                hidden_inner_size=256,
+                num_heads=8,
+                head_dim=32,
+                max_position=2048,
+                block_size=64,
+                num_hidden_layer=12,
+                layer_idx=0,
+                linear_layer_idx=0,
+            ),
+            LinearAttentionBackend,
+            "linear_attention",
+        ),
+        (
+            ShortConv,
+            dict(
+                config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
+                dim=128,
+                layer_idx=0,
+            ),
+            ShortConvAttentionBackend,
+            "short_conv",
+        ),
+    ])
+def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
+                                       expected_backend, expected_mamba_type):
+    """Test that Mamba-like layers return the correct attention backend."""
+    layer = layer_class(**init_kwargs)
+
+    backend_class = layer.get_attn_backend()
+    assert backend_class is expected_backend
+    assert layer.mamba_type == expected_mamba_type
+
+
+@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
+    (MambaMixer, Mamba1AttentionBackend, "mamba1"),
+    (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
+    (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
+    (ShortConv, ShortConvAttentionBackend, "short_conv"),
+])
+def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
+                                             expected_mamba_type):
+    """Test that all Mamba layers have the unified get_attn_backend
+    interface."""
+    assert hasattr(layer_class, 'get_attn_backend'), (
+        f"{layer_class.__name__} should have get_attn_backend method")
+    assert hasattr(layer_class, 'mamba_type'), (
+        f"{layer_class.__name__} should have mamba_type property")
diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py
deleted file mode 100644
index 4245b50c71310..0000000000000
--- a/tests/v1/attention/test_mamba_selectors.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for mamba attention backend selectors."""
-
-import pytest
-
-from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
-
-
-@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
-                         argvalues=[("mamba2", Mamba2AttentionBackend)])
-def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
-    backend_class = get_mamba_attn_backend(mamba_type)
-
-    assert backend_class is expected_backend
-
-
-def test_get_mamba_attn_backend_unsupported():
-    unsupported_types = ["mamba", ""]
-
-    for mamba_type in unsupported_types:
-        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
-        with pytest.raises(NotImplementedError, match=err_message):
-            get_mamba_attn_backend(mamba_type)
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9fbead31782a9..2d288bcbe0c95 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -54,7 +55,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
-class Attention(nn.Module):
+class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
     This class takes query, key, and value tensors as input. The input tensors
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
new file mode 100644
index 0000000000000..782818f55fbc2
--- /dev/null
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Base class for attention-like layers."""
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+
+class AttentionLayerBase(ABC):
+    """
+    Base class for attention-like layers (Attention, Mamba, etc.)
+    that support the v1 engine.
+
+    This provides a common interface for getting attention backends
+    from different layer types.
+    """
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this layer."""
+        pass
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index daebe46f6f771..a524e13405807 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from collections.abc import Iterable
+from typing import TYPE_CHECKING
 
 import torch
 
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 
-class MambaBase(ABC):
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+
+class MambaBase(AttentionLayerBase):
     """
     Base class for Mamba-like layers which support the v1 engine.
     Inherit from this class if you implement a custom layer.
@@ -32,3 +38,8 @@ class MambaBase(ABC):
     @abstractmethod
     def mamba_type(self) -> str:
         pass
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        pass
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index a24e72778b34b..e704bfd451bce 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 import torch
 from torch import nn
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba1"
 
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba1_attn import (
+            Mamba1AttentionBackend)
+        return Mamba1AttentionBackend
+
     def _time_proj_bias(self) -> Optional[torch.Tensor]:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 743e520ec8ee1..bb3fdd38dbef3 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 import torch
 from torch import nn
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"
 
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.mamba2_attn import (
+            Mamba2AttentionBackend)
+        return Mamba2AttentionBackend
+
 
 def mamba_mixer2(
     hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index fead1e73e3450..335191a5c82c1 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 import torch
 
@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "short_conv"
 
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.short_conv_attn import (
+            ShortConvAttentionBackend)
+        return ShortConvAttentionBackend
+
 
 def short_conv(
     hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 82e96844cd5f6..0e854bd7d913d 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -4,7 +4,10 @@
 import copy
 import math
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 import regex as re
 import torch
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"
 
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.linear_attn import (
+            LinearAttentionBackend)
+        return LinearAttentionBackend
+
     def get_state_dtype(self) -> tuple[torch.dtype]:
         return MambaStateDtypeCalculator.linear_attention_state_dtype(
             self.model_config.dtype,
diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py
deleted file mode 100644
index fb1844508211b..0000000000000
--- a/vllm/v1/attention/backends/mamba_selectors.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.attention.backends.abstract import AttentionBackend
-from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
-from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-from vllm.v1.attention.backends.short_conv_attn import (
-    ShortConvAttentionBackend)
-
-
-def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
-    if mamba_type == "mamba1":
-        return Mamba1AttentionBackend
-    if mamba_type == "mamba2":
-        return Mamba2AttentionBackend
-    if mamba_type == "linear_attention":
-        return LinearAttentionBackend
-    if mamba_type == "short_conv":
-        return ShortConvAttentionBackend
-
-    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
-                              "supported yet.")
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d634cf280f7fd..73117c75b9af5 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
 
         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than using
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # they are cached correctly, there will be different objects per
             # layer.
             for layer_name in layer_names:
-                attn_backend = attn_layers[layer_name].get_attn_backend()
+                attn_backend = layers[layer_name].get_attn_backend()
                 key = attn_backend.full_cls_name()
                 attn_backends[key] = attn_backend
                 attn_backend_layers[key].append(layer_name)
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
 
-            if isinstance(kv_cache_spec, AttentionSpec):
-                attn_backends = get_attn_backends_for_layers(
-                    kv_cache_group_spec.layer_names)
-                # TODO(lucas): move `get_mamba_attn_backend` into the mamba
-                # layers like above
-            elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backends = {
-                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
-                    kv_cache_group_spec.layer_names
-                }
-            else:
-                raise ValueError(
-                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
+            attn_backends = get_attn_backends_for_layers(
+                kv_cache_group_spec.layer_names)
             self.attn_groups.append(
                 create_attn_groups(attn_backends, kv_cache_spec))
 
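
A minimal usage sketch (not part of this change; the class name below is hypothetical): any layer that implements the `AttentionLayerBase` contract introduced above can be grouped by the model runner through `get_attn_backend()`, with no per-type dispatch like the removed `get_mamba_attn_backend`.

```python
# Hypothetical example only -- it illustrates the AttentionLayerBase contract
# added in this diff; ToyAttentionLikeLayer does not exist in vLLM.
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend


class ToyAttentionLikeLayer(AttentionLayerBase):
    """Returns a fixed backend class; real layers pick their own."""

    def get_attn_backend(self) -> type:
        return Mamba2AttentionBackend


# Runner-side grouping then reduces to asking each layer directly:
layer = ToyAttentionLikeLayer()
assert layer.get_attn_backend() is Mamba2AttentionBackend
```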