diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py
new file mode 100644
index 0000000000000..8eaafc5e16816
--- /dev/null
+++ b/tests/v1/attention/test_mamba_selectors.py
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mamba attention backend selectors."""
+
+import pytest
+
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
+
+
+@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
+                         argvalues=[("mamba2", Mamba2AttentionBackend)])
+def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
+    backend_class = get_mamba_attn_backend(mamba_type)
+
+    assert backend_class is expected_backend
+
+
+def test_get_mamba_attn_backend_unsupported():
+    unsupported_types = ["mamba", ""]
+
+    for mamba_type in unsupported_types:
+        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
+        with pytest.raises(NotImplementedError, match=err_message):
+            get_mamba_attn_backend(mamba_type)
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index 4c4997b4894aa..daebe46f6f771 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -27,3 +27,8 @@ class MambaBase(ABC):
         In this case, returns (conv_state_shape, ssm_state_shape).
         """
         pass
+
+    @property
+    @abstractmethod
+    def mamba_type(self) -> str:
+        pass
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 2c95099e53ad6..36edac2375d0e 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -732,6 +732,10 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_kernel=self.conv_kernel_size,
         )
 
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
+
 
 def mamba_mixer2(
     hidden_states: torch.Tensor,
diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py
new file mode 100644
index 0000000000000..80021a2165567
--- /dev/null
+++ b/vllm/v1/attention/backends/mamba_selectors.py
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+
+
+def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
+    if mamba_type == "mamba2":
+        return Mamba2AttentionBackend
+
+    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
+                              "supported yet.")
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index bec31a7a058d2..1da5230116d26 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -200,13 +200,14 @@ class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
+    mamba_type: str = "mamba2"
 
     def __post_init__(self):
         self.num_elements = sum(prod(shape) for shape in self.shapes)
 
     @property
     def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}"
+        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
 
     @property
     def page_size_bytes(self) -> int:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 559515d48e471..f3384b5756618 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -44,7 +44,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
-from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
@@ -2539,7 +2539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     "Non-Attention backend is not supported by V1 "
                     "GPUModelRunner.")
             elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backend_i = Mamba2AttentionBackend
+                attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
             else:
                 raise ValueError(
                     f"Unknown KV cache spec type: {type(kv_cache_spec)}")
@@ -2919,7 +2919,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 shapes=mamba_module.get_state_shape(),
                 dtype=self.kv_cache_dtype,
                 block_size=max_model_len,
-                page_size_padded=page_size_padded)
+                page_size_padded=page_size_padded,
+                mamba_type=mamba_module.mamba_type)
        return kv_cache_spec
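
Usage sketch (not part of the diff): a minimal illustration of the selector path introduced above, assuming only the two modules touched in this change are importable; the "mamba1" string is purely an illustrative unsupported value, not a type registered anywhere in vLLM.

from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

# A layer whose `mamba_type` property reports "mamba2" (as MambaMixer2 now
# does) resolves to the Mamba2 backend class, which is what GPUModelRunner
# does per MambaSpec after this change.
backend_cls = get_mamba_attn_backend("mamba2")
assert backend_cls is Mamba2AttentionBackend

# Any other type string (e.g. the hypothetical "mamba1") raises
# NotImplementedError until a backend is registered for it in the selector.
try:
    get_mamba_attn_backend("mamba1")
except NotImplementedError as exc:
    print(exc)  # Mamba Attention type mamba1 is not supported yet.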