[v1][mamba] Added mamba_type into MambaSpec (#21715)
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>

parent 139a7f07bd
commit a6c050286a
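
In short: each Mamba layer now reports a `mamba_type` string, the GPU model runner stores it on `MambaSpec`, and the attention backend is looked up from that field instead of being hard-coded to `Mamba2AttentionBackend`. A minimal sketch of the resulting flow, using only names that appear in the hunks below (assumes a vLLM build containing this commit):

from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

# Inside GPUModelRunner the spec is now built roughly as (see the final hunk):
#   MambaSpec(shapes=mamba_module.get_state_shape(),
#             dtype=self.kv_cache_dtype,
#             block_size=max_model_len,
#             page_size_padded=page_size_padded,
#             mamba_type=mamba_module.mamba_type)   # <- new field
# and the backend class is then resolved from that field:
assert get_mamba_attn_backend("mamba2") is Mamba2AttentionBackend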
tests/v1/attention/test_mamba_selectors.py  (new file, 25 lines added)
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mamba attention backend selectors."""
+
+import pytest
+
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
+
+
+@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
+                         argvalues=[("mamba2", Mamba2AttentionBackend)])
+def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
+    backend_class = get_mamba_attn_backend(mamba_type)
+
+    assert backend_class is expected_backend
+
+
+def test_get_mamba_attn_backend_unsupported():
+    unsupported_types = ["mamba", ""]
+
+    for mamba_type in unsupported_types:
+        err_message = f"Mamba Attention type {mamba_type} is not supported yet."
+        with pytest.raises(NotImplementedError, match=err_message):
+            get_mamba_attn_backend(mamba_type)
@@ -27,3 +27,8 @@ class MambaBase(ABC):
         In this case, returns (conv_state_shape, ssm_state_shape).
         """
         pass
+
+    @property
+    @abstractmethod
+    def mamba_type(self) -> str:
+        pass
@@ -732,6 +732,10 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_kernel=self.conv_kernel_size,
         )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
 
 
 def mamba_mixer2(
     hidden_states: torch.Tensor,
vllm/v1/attention/backends/mamba_selectors.py  (new file, 12 lines added)
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+
+
+def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
+    if mamba_type == "mamba2":
+        return Mamba2AttentionBackend
+
+    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
+                              "supported yet.")
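
For reference, a quick interactive check of the unsupported path (a sketch; assumes a vLLM checkout that includes this commit) raises the same error message the new tests match against:

from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

try:
    get_mamba_attn_backend("mamba")   # only "mamba2" is wired up so far
except NotImplementedError as exc:
    print(exc)   # Mamba Attention type mamba is not supported yet.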
@@ -200,13 +200,14 @@ class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
+    mamba_type: str = "mamba2"
 
     def __post_init__(self):
         self.num_elements = sum(prod(shape) for shape in self.shapes)
 
     @property
     def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}"
+        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
 
     @property
     def page_size_bytes(self) -> int:
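
`type_id` is what distinguishes one KV cache spec from another, so appending `mamba_type` keeps specs that differ only in Mamba flavour from sharing an id. A small illustration of just the string construction (the shape and dtype values are made up; the two type labels are the ones used in the tests above):

shapes = ((128, 16), (16, 64, 128))   # hypothetical state shapes
dtype = "torch.bfloat16"              # stands in for a torch.dtype
old_id = f"mamba_{shapes}_{dtype}"                      # identical for both flavours
new_ids = {f"mamba_{shapes}_{dtype}_{t}" for t in ("mamba2", "mamba")}
assert len(new_ids) == 2   # distinct ids after this change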
@@ -44,7 +44,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
-from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
+from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
@@ -2539,7 +2539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                        "Non-Attention backend is not supported by V1 "
                        "GPUModelRunner.")
            elif isinstance(kv_cache_spec, MambaSpec):
-               attn_backend_i = Mamba2AttentionBackend
+               attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type)
            else:
                raise ValueError(
                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
@@ -2919,7 +2919,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    shapes=mamba_module.get_state_shape(),
                    dtype=self.kv_cache_dtype,
                    block_size=max_model_len,
-                   page_size_padded=page_size_padded)
+                   page_size_padded=page_size_padded,
+                   mamba_type=mamba_module.mamba_type)
 
        return kv_cache_spec
 