[v1][mamba] Added mamba_type into MambaSpec (#21715)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Asaf Joseph Gardin 2025-07-28 11:15:55 +03:00 committed by GitHub
parent 139a7f07bd
commit a6c050286a
6 changed files with 52 additions and 4 deletions


@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mamba attention backend selectors."""
+import pytest
+
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
+
+
+@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
+                         argvalues=[("mamba2", Mamba2AttentionBackend)])
+def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
+    backend_class = get_mamba_attn_backend(mamba_type)
+
+    assert backend_class is expected_backend
+
+
+def test_get_mamba_attn_backend_unsupported():
+    unsupported_types = ["mamba", ""]
+
+    for mamba_type in unsupported_types:
+        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
+
+        with pytest.raises(NotImplementedError, match=err_message):
+            get_mamba_attn_backend(mamba_type)
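
For reference, the behaviour these tests pin down can be reproduced directly in a Python session; this is a minimal sketch assuming a vLLM checkout that already contains this change:

from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

# A supported type resolves to the backend class itself (not an instance).
assert get_mamba_attn_backend("mamba2") is Mamba2AttentionBackend

# Unsupported types raise NotImplementedError with a descriptive message.
try:
    get_mamba_attn_backend("mamba")
except NotImplementedError as exc:
    print(exc)  # Mamba Attention type mamba is not supported yet.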


@@ -27,3 +27,8 @@ class MambaBase(ABC):
         In this case, returns (conv_state_shape, ssm_state_shape).
         """
         pass
+
+    @property
+    @abstractmethod
+    def mamba_type(self) -> str:
+        pass
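
With mamba_type now abstract on MambaBase, every concrete mixer has to advertise which attention backend family it belongs to. A minimal sketch of the contract, using trimmed stand-in classes (the _Toy* names are illustrative and not part of this change; in vLLM the real implementer is MambaMixer2 below):

from abc import ABC, abstractmethod


class _ToyMambaBase(ABC):
    """Trimmed stand-in for MambaBase, only to show the new contract."""

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        pass


class _ToyMamba2Layer(_ToyMambaBase):
    """Hypothetical concrete layer implementing the required property."""

    @property
    def mamba_type(self) -> str:
        return "mamba2"


print(_ToyMamba2Layer().mamba_type)  # -> mamba2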


@@ -732,6 +732,10 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_kernel=self.conv_kernel_size,
         )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
 
 
 def mamba_mixer2(
     hidden_states: torch.Tensor,

@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+
+
+def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
+    if mamba_type == "mamba2":
+        return Mamba2AttentionBackend
+
+    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
+                              "supported yet.")


@@ -200,13 +200,14 @@ class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
+    mamba_type: str = "mamba2"
 
     def __post_init__(self):
         self.num_elements = sum(prod(shape) for shape in self.shapes)
 
     @property
     def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}"
+        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
 
     @property
     def page_size_bytes(self) -> int:

@@ -44,7 +44,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
-from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
@@ -2539,7 +2539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     "Non-Attention backend is not supported by V1 "
                     "GPUModelRunner.")
             elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backend_i = Mamba2AttentionBackend
+                attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
             else:
                 raise ValueError(
                     f"Unknown KV cache spec type: {type(kv_cache_spec)}")
@@ -2919,7 +2919,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 shapes=mamba_module.get_state_shape(),
                 dtype=self.kv_cache_dtype,
                 block_size=max_model_len,
-                page_size_padded=page_size_padded)
+                page_size_padded=page_size_padded,
+                mamba_type=mamba_module.mamba_type)
 
         return kv_cache_spec
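
Taken together, the plumbing in this commit flows mixer -> spec -> backend: MambaMixer2.mamba_type returns "mamba2", the model runner copies it into MambaSpec.mamba_type when building the KV cache spec, and backend selection then keys off that field instead of hard-coding Mamba2AttentionBackend. A minimal sketch of the final resolution step, outside the full GPUModelRunner machinery:

from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

mamba_type = "mamba2"  # stands in for kv_cache_spec.mamba_type
attn_backend = get_mamba_attn_backend(mamba_type)
print(attn_backend.__name__)  # Mamba2AttentionBackend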