[Model][Mamba] Add selector for mamba attention backend and make it pluggable for other device (#26487)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent 48fc8b1e59
commit d44e9df7d4
@@ -146,6 +146,7 @@ We use "mamba-like" to refer to layers that posses a state that is updated in-pl
 For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`.
 It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers.
 Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](../../../vllm/v1/attention/backends/short_conv_attn.py) for examples of this.
+It is also worth noting that we should update `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/attention/backends/registry.py) when adding a new mamba backend.
 Finally, if one wants to support torch compile and CUDA graphs, it is necessary to wrap the call to the mamba-like layer inside a custom op and register it.
 Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this.
 The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs work as intended.
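To make the documented workflow concrete, here is a hedged sketch of how an out-of-tree mamba-like layer could plug into the new registry. The `My*` names are hypothetical, the `MambaBase` import path is assumed to be `vllm.model_executor.layers.mamba.abstract`, and the layer is stripped down to the selector-related methods (a real layer would also implement its forward pass, wrap it in a custom op, and add it to `_attention_ops` as described above).

```python
# Illustrative sketch only; MyDeviceMambaBackend and MyMambaLikeLayer are hypothetical names.
import torch

from vllm.attention.backends.registry import (
    MambaAttentionBackendEnum,
    register_backend,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase  # assumed path


# Fill the CUSTOM slot of the mamba backend registry with a third-party backend class.
@register_backend(MambaAttentionBackendEnum.CUSTOM, is_mamba=True)
class MyDeviceMambaBackend:
    """Hypothetical attention backend for another device."""


class MyMambaLikeLayer(MambaBase):
    @property
    def mamba_type(self) -> str:
        # "custom" maps to MambaAttentionBackendEnum.CUSTOM via MAMBA_TYPE_TO_BACKEND_MAP,
        # so the inherited MambaBase.get_attn_backend() resolves to MyDeviceMambaBackend.
        return "custom"

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        return (torch.float32,)

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        return ((64, 128),)  # placeholder state shape for illustration
```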
@@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import (
     AttentionType,
 )
 from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend

 __all__ = [
     "Attention",
@@ -15,4 +15,5 @@ __all__ = [
     "AttentionMetadata",
     "AttentionType",
     "get_attn_backend",
+    "get_mamba_attn_backend",
 ]
@@ -2,8 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend registry"""

-import enum
 from collections.abc import Callable
+from enum import Enum, EnumMeta
 from typing import TYPE_CHECKING, cast

 from vllm.logger import init_logger
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


-class _AttentionBackendEnumMeta(enum.EnumMeta):
+class _AttentionBackendEnumMeta(EnumMeta):
     """Metaclass for AttentionBackendEnum to provide better error messages."""

     def __getitem__(cls, name: str):
@@ -23,15 +23,15 @@ class _AttentionBackendEnumMeta(enum.EnumMeta):
         try:
             return super().__getitem__(name)
         except KeyError:
-            members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
-            valid_backends = ", ".join(m.name for m in members)
+            members = cast("dict[str, Enum]", cls.__members__).keys()
+            valid_backends = ", ".join(members)
             raise ValueError(
                 f"Unknown attention backend: '{name}'. "
                 f"Valid options are: {valid_backends}"
             ) from None


-class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
+class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     """Enumeration of all supported attention backends.

     The enum value is the default class path, but this can be overridden
@@ -83,7 +83,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
         Raises:
             ValueError: If Backend.CUSTOM is used without being registered
         """
-        path = _OVERRIDES.get(self, self.value)
+        path = _ATTN_OVERRIDES.get(self, self.value)
         if not path:
             raise ValueError(
                 f"Backend {self.name} must be registered before use. "
@@ -111,18 +111,93 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
         Returns:
             True if the backend has a registered override
         """
-        return self in _OVERRIDES
+        return self in _ATTN_OVERRIDES

     def clear_override(self) -> None:
         """Clear any override for this backend, reverting to the default."""
-        _OVERRIDES.pop(self, None)
+        _ATTN_OVERRIDES.pop(self, None)


-_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
+    """Enumeration of all supported mamba attention backends.
+
+    The enum value is the default class path, but this can be overridden
+    at runtime using register_backend().
+
+    To get the actual backend class (respecting overrides), use:
+        backend.get_class()
+    """
+
+    MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
+    MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
+    SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
+    LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
+    GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
+    # Placeholder for third-party/custom backends - must be registered before use
+    CUSTOM = ""
+
+    def get_path(self, include_classname: bool = True) -> str:
+        """Get the class path for this backend (respects overrides).
+
+        Returns:
+            The fully qualified class path string
+
+        Raises:
+            ValueError: If Backend.CUSTOM is used without being registered
+        """
+        path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
+        if not path:
+            raise ValueError(
+                f"Backend {self.name} must be registered before use. "
+                f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
+            )
+        if not include_classname:
+            path = path.rsplit(".", 1)[0]
+        return path
+
+    def get_class(self) -> "type[AttentionBackend]":
+        """Get the backend class (respects overrides).
+
+        Returns:
+            The backend class
+
+        Raises:
+            ImportError: If the backend class cannot be imported
+            ValueError: If Backend.CUSTOM is used without being registered
+        """
+        return resolve_obj_by_qualname(self.get_path())
+
+    def is_overridden(self) -> bool:
+        """Check if this backend has been overridden.
+
+        Returns:
+            True if the backend has a registered override
+        """
+        return self in _MAMBA_ATTN_OVERRIDES
+
+    def clear_override(self) -> None:
+        """Clear any override for this backend, reverting to the default."""
+        _MAMBA_ATTN_OVERRIDES.pop(self, None)
+
+
+MAMBA_TYPE_TO_BACKEND_MAP = {
+    "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
+    "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
+    "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
+    "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
+    "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
+    "custom": MambaAttentionBackendEnum.CUSTOM.name,
+}
+
+
+_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}


 def register_backend(
-    backend: AttentionBackendEnum, class_path: str | None = None
+    backend: AttentionBackendEnum | MambaAttentionBackendEnum,
+    is_mamba: bool = False,
+    class_path: str | None = None,
 ) -> Callable[[type], type]:
     """Register or override a backend implementation.

@@ -135,12 +210,17 @@ def register_backend(
         Decorator function if class_path is None, otherwise a no-op

     Examples:
-        # Override an existing backend
+        # Override an existing attention backend
        @register_backend(AttentionBackendEnum.FLASH_ATTN)
        class MyCustomFlashAttn:
            ...

-        # Register a custom third-party backend
+        # Override an existing mamba attention backend
+        @register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
+        class MyCustomMambaAttn:
+            ...
+
+        # Register a custom third-party attention backend
        @register_backend(AttentionBackendEnum.CUSTOM)
        class MyCustomBackend:
            ...
@@ -153,11 +233,17 @@ def register_backend(
     """

     def decorator(cls: type) -> type:
-        _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"  # type: ignore[index]
         return cls

     if class_path is not None:
-        _OVERRIDES[backend] = class_path
+        if is_mamba:
+            _MAMBA_ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
+        else:
+            _ATTN_OVERRIDES[backend] = class_path  # type: ignore[index]
         return lambda x: x

     return decorator
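Besides the decorator form shown in the docstring above, `register_backend` also accepts an explicit `class_path`, which lets a device plugin point the `CUSTOM` slot at its backend without importing it eagerly. A minimal sketch, assuming a hypothetical plugin module path (`my_plugin.attn.MyDeviceMambaBackend`):

```python
# Sketch of the non-decorator registration path added in this commit.
from vllm.attention.backends.registry import (
    MambaAttentionBackendEnum,
    register_backend,
)

register_backend(
    MambaAttentionBackendEnum.CUSTOM,
    is_mamba=True,
    class_path="my_plugin.attn.MyDeviceMambaBackend",  # hypothetical out-of-tree class
)

# The class is only imported when the backend is actually resolved via get_class();
# the override itself is just a string in _MAMBA_ATTN_OVERRIDES.
assert MambaAttentionBackendEnum.CUSTOM.is_overridden()
print(MambaAttentionBackendEnum.CUSTOM.get_path())
# -> my_plugin.attn.MyDeviceMambaBackend
```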
@@ -12,7 +12,11 @@ import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.backends.registry import (
+    MAMBA_TYPE_TO_BACKEND_MAP,
+    AttentionBackendEnum,
+    MambaAttentionBackendEnum,
+)
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.utils import STR_BACKEND_ENV_VAR
@@ -197,6 +201,33 @@ def _cached_get_attn_backend(
     return backend


+def get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    """Select which mamba attention backend to use and lazily import it."""
+    return _cached_get_mamba_attn_backend(mamba_type)
+
+
+@cache
+def _cached_get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    assert mamba_type and isinstance(mamba_type, str)
+
+    selected_backend = None
+    try:
+        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
+        selected_backend = MambaAttentionBackendEnum[backend_name]
+    except KeyError as e:
+        raise ValueError(
+            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
+            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
+        ) from e
+
+    mamba_attn_backend = selected_backend.get_class()
+    return mamba_attn_backend
+
+
 @contextmanager
 def global_force_attn_backend_context_manager(
     attn_backend: AttentionBackendEnum,
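For illustration, resolving a backend through the new selector could look like the following minimal sketch (assuming a vLLM build that contains this change):

```python
# Minimal usage sketch of the new mamba attention backend selector.
from vllm.attention.selector import get_mamba_attn_backend

# "mamba2" -> MambaAttentionBackendEnum.MAMBA2 -> Mamba2AttentionBackend;
# the backend class is imported lazily and the lookup is cached via @cache.
backend_cls = get_mamba_attn_backend("mamba2")
print(backend_cls.__name__)  # Mamba2AttentionBackend
```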
@@ -5,7 +5,6 @@ import torch
 from einops import rearrange
 from torch import nn

-from vllm.attention import AttentionBackend
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (
@@ -83,12 +82,7 @@ direct_register_custom_op(
 class KimiDeltaAttention(nn.Module, MambaBase):
     @property
     def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
-        return GDNAttentionBackend
+        return "gdn_attention"

     def get_state_dtype(
         self,
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING

 import torch

+from vllm.attention.selector import get_mamba_attn_backend
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
@@ -38,11 +39,6 @@ class MambaBase(AttentionLayerBase):
     def mamba_type(self) -> str:
         pass

-    @abstractmethod
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        """Get the attention backend class for this Mamba layer."""
-        pass
-
     @abstractmethod
     def get_state_dtype(self) -> tuple[torch.dtype, ...]:
         pass
@@ -69,3 +65,7 @@ class MambaBase(AttentionLayerBase):
                 else 0
             ),
         )
+
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        return get_mamba_attn_backend(self.mamba_type)
@@ -2,12 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import math
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 from typing import TYPE_CHECKING

 import torch
 import torch.nn.functional as F
@@ -37,9 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-

 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -123,11 +114,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
-
-        return LinearAttentionBackend
-
     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None
         assert self.cache_config is not None
@@ -1,10 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING, NamedTuple
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
+from typing import NamedTuple

 import torch
 from torch import nn
@@ -452,11 +449,6 @@ class MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba1"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-
-        return Mamba1AttentionBackend
-
     def _time_proj_bias(self) -> torch.Tensor | None:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()
@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -908,11 +904,6 @@ class MambaMixer2(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-

 def mamba_mixer2(
     projected_states: torch.Tensor,
@@ -1,10 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend

 import torch
@@ -232,11 +228,6 @@ class ShortConv(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "short_conv"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
-
-        return ShortConvAttentionBackend
-

 def short_conv(
     hidden_states: torch.Tensor,
@@ -4,10 +4,6 @@

 from collections.abc import Iterable
 from itertools import islice
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend

 import torch
 from torch import nn
@@ -467,11 +463,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
     def mamba_type(self) -> str:
         return "mamba2"

-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
-        return Mamba2AttentionBackend
-

 def plamo2_mamba_mixer(
     hidden_states: torch.Tensor,
@@ -10,7 +10,7 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN

-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CacheConfig,
@@ -216,12 +216,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
 class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
     @property
     def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
-        return GDNAttentionBackend
+        return "gdn_attention"

     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
         return MambaStateDtypeCalculator.gated_delta_net_state_dtype(