[Model][Mamba] Add selector for mamba attention backend and make it pluggable for other device (#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>
Shanshan Shen 2025-11-20 00:24:55 +08:00 committed by GitHub
parent 48fc8b1e59
commit d44e9df7d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 144 additions and 85 deletions

View File

@@ -146,6 +146,7 @@ We use "mamba-like" to refer to layers that possess a state that is updated in-place
 For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`.
 It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers.
 Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](../../../vllm/v1/attention/backends/short_conv_attn.py) for examples of this.
+It is also worth noting that we should update `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/attention/backends/registry.py) when adding a new mamba backend.
 Finally, if one wants to support torch compile and CUDA graphs, it is necessary to wrap the call to the mamba-like layer inside a custom op and register it.
 Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this.
 The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs work as intended.
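
To make this concrete, here is a minimal sketch of a custom mamba-like layer under this scheme (illustrative only, not part of the diff). The class name and state shape are hypothetical placeholders, and it assumes `MambaBase` lives at `vllm.model_executor.layers.mamba.abstract`; with the default `get_attn_backend` now provided by `MambaBase` (see the abstract-class diff below), the layer itself no longer imports any backend class:

# Hypothetical custom layer; MyLinearLayer and the state shape are placeholders.
import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class MyLinearLayer(MambaBase):
    @property
    def mamba_type(self) -> str:
        # Must match a key of MAMBA_TYPE_TO_BACKEND_MAP so that the inherited
        # get_attn_backend() can resolve a backend for this layer.
        return "linear_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        return (torch.float32,)

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # Placeholder shape; real layers derive this from the model config.
        return ((16, 64, 64),)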

View File

@@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import (
     AttentionType,
 )
 from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend

 __all__ = [
     "Attention",
@@ -15,4 +15,5 @@ __all__ = [
     "AttentionMetadata",
     "AttentionType",
     "get_attn_backend",
+    "get_mamba_attn_backend",
 ]

View File

@@ -2,8 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend registry"""

-import enum
 from collections.abc import Callable
+from enum import Enum, EnumMeta
 from typing import TYPE_CHECKING, cast

 from vllm.logger import init_logger
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


-class _AttentionBackendEnumMeta(enum.EnumMeta):
+class _AttentionBackendEnumMeta(EnumMeta):
     """Metaclass for AttentionBackendEnum to provide better error messages."""

     def __getitem__(cls, name: str):
@@ -23,15 +23,15 @@ class _AttentionBackendEnumMeta(enum.EnumMeta):
         try:
             return super().__getitem__(name)
         except KeyError:
-            members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
-            valid_backends = ", ".join(m.name for m in members)
+            members = cast("dict[str, Enum]", cls.__members__).keys()
+            valid_backends = ", ".join(members)
             raise ValueError(
                 f"Unknown attention backend: '{name}'. "
                 f"Valid options are: {valid_backends}"
             ) from None


-class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
+class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     """Enumeration of all supported attention backends.

     The enum value is the default class path, but this can be overridden
@@ -83,7 +83,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
         Raises:
             ValueError: If Backend.CUSTOM is used without being registered
         """
-        path = _OVERRIDES.get(self, self.value)
+        path = _ATTN_OVERRIDES.get(self, self.value)
         if not path:
             raise ValueError(
                 f"Backend {self.name} must be registered before use. "
@@ -111,18 +111,93 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
         Returns:
             True if the backend has a registered override
         """
-        return self in _OVERRIDES
+        return self in _ATTN_OVERRIDES

     def clear_override(self) -> None:
         """Clear any override for this backend, reverting to the default."""
-        _OVERRIDES.pop(self, None)
+        _ATTN_OVERRIDES.pop(self, None)


-_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
+    """Enumeration of all supported mamba attention backends.
+
+    The enum value is the default class path, but this can be overridden
+    at runtime using register_backend().
+
+    To get the actual backend class (respecting overrides), use:
+        backend.get_class()
+    """
+
+    MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
+    MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
+    SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
+    LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
+    GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
+    # Placeholder for third-party/custom backends - must be registered before use
+    CUSTOM = ""
+
+    def get_path(self, include_classname: bool = True) -> str:
+        """Get the class path for this backend (respects overrides).
+
+        Returns:
+            The fully qualified class path string
+
+        Raises:
+            ValueError: If Backend.CUSTOM is used without being registered
+        """
+        path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
+        if not path:
+            raise ValueError(
+                f"Backend {self.name} must be registered before use. "
+                f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
+            )
+        if not include_classname:
+            path = path.rsplit(".", 1)[0]
+        return path
+
+    def get_class(self) -> "type[AttentionBackend]":
+        """Get the backend class (respects overrides).
+
+        Returns:
+            The backend class
+
+        Raises:
+            ImportError: If the backend class cannot be imported
+            ValueError: If Backend.CUSTOM is used without being registered
+        """
+        return resolve_obj_by_qualname(self.get_path())
+
+    def is_overridden(self) -> bool:
+        """Check if this backend has been overridden.
+
+        Returns:
+            True if the backend has a registered override
+        """
+        return self in _MAMBA_ATTN_OVERRIDES
+
+    def clear_override(self) -> None:
+        """Clear any override for this backend, reverting to the default."""
+        _MAMBA_ATTN_OVERRIDES.pop(self, None)
+
+
+MAMBA_TYPE_TO_BACKEND_MAP = {
+    "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
+    "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
+    "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
+    "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
+    "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
+    "custom": MambaAttentionBackendEnum.CUSTOM.name,
+}
+
+
+_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}


 def register_backend(
-    backend: AttentionBackendEnum, class_path: str | None = None
+    backend: AttentionBackendEnum | MambaAttentionBackendEnum,
+    is_mamba: bool = False,
+    class_path: str | None = None,
 ) -> Callable[[type], type]:
     """Register or override a backend implementation.

@@ -135,12 +210,17 @@ def register_backend(
         Decorator function if class_path is None, otherwise a no-op

     Examples:
-        # Override an existing backend
+        # Override an existing attention backend
         @register_backend(AttentionBackendEnum.FLASH_ATTN)
         class MyCustomFlashAttn:
             ...

-        # Register a custom third-party backend
+        # Override an existing mamba attention backend
+        @register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
+        class MyCustomMambaAttn:
+            ...
+
+        # Register a custom third-party attention backend
         @register_backend(AttentionBackendEnum.CUSTOM)
         class MyCustomBackend:
             ...
@@ -153,11 +233,17 @@ def register_backend(
     """

     def decorator(cls: type) -> type:
-        _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
         return cls

     if class_path is not None:
-        _OVERRIDES[backend] = class_path
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
         return lambda x: x

     return decorator
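
As a usage illustration (not part of the diff), a third-party device plugin could claim the CUSTOM slot through the decorator form above; `MyNpuMambaBackend` is a hypothetical placeholder, while `register_backend`, the enum, and its helper methods are the APIs introduced here:

# Hypothetical plugin code; only register_backend and the enum are real.
from vllm.attention.backends.registry import (
    MambaAttentionBackendEnum,
    register_backend,
)


@register_backend(MambaAttentionBackendEnum.CUSTOM, is_mamba=True)
class MyNpuMambaBackend:  # a real plugin would subclass AttentionBackend
    ...


assert MambaAttentionBackendEnum.CUSTOM.is_overridden()
# clear_override() reverts CUSTOM to its (empty) default path.
MambaAttentionBackendEnum.CUSTOM.clear_override()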

View File

@@ -12,7 +12,11 @@ import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.backends.registry import (
+    MAMBA_TYPE_TO_BACKEND_MAP,
+    AttentionBackendEnum,
+    MambaAttentionBackendEnum,
+)
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.utils import STR_BACKEND_ENV_VAR
@@ -197,6 +201,33 @@
     return backend


+def get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    """Select which mamba attention backend to use and lazily import it."""
+    return _cached_get_mamba_attn_backend(mamba_type)
+
+
+@cache
+def _cached_get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    assert mamba_type and isinstance(mamba_type, str)
+
+    try:
+        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
+        selected_backend = MambaAttentionBackendEnum[backend_name]
+    except KeyError as e:
+        # Report mamba_type here: backend_name is unbound when the map
+        # lookup itself is what raised KeyError.
+        raise ValueError(
+            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
+            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
+        ) from e
+
+    return selected_backend.get_class()
+
+
 @contextmanager
 def global_force_attn_backend_context_manager(
     attn_backend: AttentionBackendEnum,
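
A quick sketch of calling the selector (illustrative, not part of the diff); the valid mamba_type strings are the keys of MAMBA_TYPE_TO_BACKEND_MAP from registry.py:

# Resolve the backend class for a mamba2-style layer.
from vllm.attention.selector import get_mamba_attn_backend

backend_cls = get_mamba_attn_backend("mamba2")
# The resolver is wrapped in @cache, so repeated lookups for the same
# mamba_type return the same class object without re-importing.
assert get_mamba_attn_backend("mamba2") is backend_cls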

View File

@@ -5,7 +5,6 @@ import torch
 from einops import rearrange
 from torch import nn

-from vllm.attention import AttentionBackend
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (
@@ -83,12 +82,7 @@ direct_register_custom_op(
 class KimiDeltaAttention(nn.Module, MambaBase):
     @property
     def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
-        return GDNAttentionBackend
+        return "gdn_attention"

     def get_state_dtype(
         self,

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING

 import torch

+from vllm.attention.selector import get_mamba_attn_backend
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
@@ -38,11 +39,6 @@ class MambaBase(AttentionLayerBase):
     def mamba_type(self) -> str:
         pass

-    @abstractmethod
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        """Get the attention backend class for this Mamba layer."""
-        pass
-
     @abstractmethod
     def get_state_dtype(self) -> tuple[torch.dtype, ...]:
         pass
@@ -69,3 +65,7 @@ class MambaBase(AttentionLayerBase):
                 else 0
             ),
         )
+
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        return get_mamba_attn_backend(self.mamba_type)
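
With the hook now concrete on the base class, a platform plugin can redirect every layer of a given mamba_type at once via the class_path form of register_backend; a rough sketch follows, where the plugin module path is a hypothetical placeholder, and the override must be installed before the first lookup since the selector result is cached:

# Hypothetical device-plugin override; no model code changes required.
from vllm.attention.backends.registry import (
    MambaAttentionBackendEnum,
    register_backend,
)

register_backend(
    MambaAttentionBackendEnum.MAMBA2,
    is_mamba=True,
    class_path="my_plugin.attention.MyDeviceMamba2Backend",  # placeholder path
)
# Every layer whose mamba_type is "mamba2" now resolves to the plugin class
# through MambaBase.get_attn_backend().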

View File

@@ -2,12 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import math
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
-from typing import TYPE_CHECKING

 import torch
 import torch.nn.functional as F
@@ -37,9 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-

 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -123,11 +114,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
-
-        return LinearAttentionBackend
-
     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None
         assert self.cache_config is not None

View File

@@ -1,10 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING, NamedTuple
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
+from typing import NamedTuple

 import torch
 from torch import nn
@@ -452,11 +449,6 @@ class MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba1"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-
-        return Mamba1AttentionBackend
-
     def _time_proj_bias(self) -> torch.Tensor | None:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()

View File

@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 import torch
 from torch import nn
@@ -908,11 +904,6 @@ class MambaMixer2(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-

 def mamba_mixer2(
     projected_states: torch.Tensor,

View File

@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 import torch
@@ -232,11 +228,6 @@ class ShortConv(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "short_conv"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
-
-        return ShortConvAttentionBackend
-

 def short_conv(
     hidden_states: torch.Tensor,

View File

@@ -4,10 +4,6 @@

 from collections.abc import Iterable
 from itertools import islice
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -467,11 +463,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-

 def plamo2_mamba_mixer(
     hidden_states: torch.Tensor,

View File

@@ -10,7 +10,7 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN

-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CacheConfig,
@@ -216,12 +216,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
 class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
     @property
     def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
-        return GDNAttentionBackend
+        return "gdn_attention"

     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
         return MambaStateDtypeCalculator.gated_delta_net_state_dtype(