[v1] - Mamba1 Attention Metadata (#21249)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Asaf Joseph Gardin 2025-08-07 03:03:42 +03:00 committed by GitHub
parent 31f09c615f
commit 46a13949d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 367 additions and 161 deletions


@@ -45,6 +45,9 @@ struct SSMParamsBase {
index_t out_d_stride;
index_t out_z_batch_stride;
index_t out_z_d_stride;
index_t ssm_states_batch_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dstate_stride;
// Common data pointers.
void *__restrict__ A_ptr;


@@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
#pragma unroll
@@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
// Initialize running total
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
}
}
#pragma unroll
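A minimal Python sketch of the carried-prefix pattern in the kernel above, assuming a toy scalar recurrence (the real kernel scans vectors per (dim, state) pair): each chunk resumes from the previous chunk's running prefix, and after the last chunk the final state is written back to the cache, mirroring the ssm_states[state_idx * params.ssm_states_dstate_stride] store.

def chunked_scan(a, b, initial_state=0.0, chunk_size=4):
    # x_t = a_t * x_{t-1} + b_t, evaluated chunk by chunk
    state, outputs = initial_state, []
    for start in range(0, len(a), chunk_size):
        for t in range(start, min(start + chunk_size, len(a))):
            state = a[t] * state + b[t]
            outputs.append(state)
    return outputs, state  # `state` plays the role of the cached ssm_state

outs, final_state = chunked_scan([0.9] * 8, [0.1] * 8, initial_state=0.5)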
@@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_batch_stride = out.stride(1);
params.out_d_stride = out.stride(0);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
}
else{
if (!is_variable_B) {
@@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
}
}
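The new fields replace the implicit contiguous-layout offset (cache_index * dim + dim_id) * dstate with explicit per-axis strides taken from ssm_states.stride(0..2). A small PyTorch sketch (hypothetical shapes) of why stride-based addressing is the safer choice:

import torch

# Hypothetical (cache_slots, dim, dstate) state tensor.
states = torch.arange(8 * 4 * 2).reshape(8, 4, 2)
i, j, k = 3, 1, 1

# Explicit stride arithmetic reproduces normal indexing...
b, d, s = states.stride()  # (8, 2, 1) when contiguous
assert states.flatten()[i * b + j * d + k * s] == states[i, j, k]

# ...and stays correct for non-contiguous views, where the old flat
# formula (i * dim + j) * dstate + k would read the wrong element.
t = states.transpose(1, 2)  # strides become (8, 1, 2)
b2, d2, s2 = t.stride()
assert states.flatten()[i * b2 + j * d2 + k * s2] == t[i, j, k]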


@@ -370,9 +370,9 @@ th {
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |


@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |
vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i
#### Mamba Models
Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
disabling prefix caching in V1.
Models using selective state-space mechanisms instead of standard transformer attention are supported.
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM`, and `JambaForCausalLM`). Please note that
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
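A minimal usage sketch for the pure Mamba-1 case described above (the FlashInfer requirement applies only to the hybrid models). The flag names are those of vLLM's LLM entrypoint; verify them against your installed version:

from vllm import LLM

llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enable_prefix_caching=False,  # prefix caching must be disabled in V1
    enforce_eager=True,           # required for Mamba-1 models
)
print(llm.generate("The state-space model computes")[0].outputs[0].text)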
#### Encoder-Decoder Models


@@ -53,6 +53,8 @@ HF_UNSUPPORTED_MODELS = [
]
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",


@@ -12,7 +12,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
]
MODEL = "meta-llama/Llama-3.2-1B-Instruct"


@@ -1,30 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.attention.backends.abstract import AttentionMetadata
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
class MambaMixer(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
@@ -47,13 +54,16 @@ class MambaMixer(CustomOp):
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False):
is_lora_enabled: bool = False,
prefix: str = ""):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
self.activation = activation
self.is_lora_enabled = is_lora_enabled
self.conv_kernel_size = conv_kernel_size
self.intermediate_size = intermediate_size
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
@@ -131,14 +141,62 @@ class MambaMixer(CustomOp):
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
def forward_native(self, hidden_states: torch.Tensor,
conv_state: torch.Tensor, ssm_state: torch.Tensor):
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.prefix = prefix
def forward(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
else:
return self.forward_cuda(hidden_states, mamba_cache_params)
def forward_native(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
pass
def forward_cuda(self, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams):
def forward_cuda(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
mamba1_metadata = attn_metadata
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
query_start_loc = mamba1_metadata.query_start_loc
state_indices_tensor = mamba1_metadata.state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
else:
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor
if context_lens_tensor is not None:
has_initial_state = context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -148,8 +206,12 @@ class MambaMixer(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = hidden_states.contiguous()
return self.out_proj(hidden_states.transpose(-2, -1))[0]
if query_start_loc is not None and context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
@@ -161,18 +223,18 @@ class MambaMixer(CustomOp):
conv_weights,
bias=self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
conv_state_indices=state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
@@ -203,11 +265,10 @@ class MambaMixer(CustomOp):
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
if query_start_loc is not None and context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
@@ -216,24 +277,23 @@ class MambaMixer(CustomOp):
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
cache_indices=state_indices_tensor,
has_initial_state=has_initial_state,
query_start_loc=query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
selective_state_update(ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
@@ -245,3 +305,15 @@ class MambaMixer(CustomOp):
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.intermediate_size,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)
@property
def mamba_type(self) -> str:
return "mamba1"


@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
extra_groups_for_head_shards, get_mamba_state_shape)
MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -278,8 +278,9 @@ class MambaMixer2(MambaBase, CustomOp):
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
self.n_groups = n_groups + extra_groups_for_head_shards(
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size)
self.n_groups = n_groups + groups
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
self.conv1d = ColumnParallelLinear(
@@ -732,7 +733,7 @@ class MambaMixer2(MambaBase, CustomOp):
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,
tp_world_size=get_tensor_model_parallel_world_size(),
n_groups=self.n_groups,


@@ -3,53 +3,70 @@
from vllm.distributed import divide
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
class MambaStateShapeCalculator:
# in the case ngroups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
@classmethod
def mamba1_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size,
tp_world_size), conv_kernel - 1)
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
temporal_state_shape = (divide(intermediate_size,
tp_world_size), state_size)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
def get_mamba_state_shape(
intermediate_size: int,
tp_world_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
""" Get the shape of mamba state."""
return conv_state_shape, temporal_state_shape
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head are sharded along with it
n_groups = (n_groups +
extra_groups_for_head_shards(n_groups, tp_world_size))
@classmethod
def mamba2_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head are sharded along with it
n_groups = n_groups + cls.extra_groups_for_head_shards(
n_groups, tp_world_size)
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size + 2 * n_groups * state_size)
# contiguous along 'dim' axis
conv_state_shape = (
conv_kernel - 1,
divide(conv_dim, tp_world_size),
)
# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
if not use_v1:
conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads,
tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (
divide(num_heads, tp_world_size),
head_dim,
state_size,
)
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
return conv_state_shape, temporal_state_shape
# in the case ngroups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups


@@ -25,7 +25,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -457,7 +458,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,


@@ -24,7 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -543,7 +544,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
if hf_config.mamba_d_ssm is None else
hf_config.mamba_d_ssm)
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,


@@ -23,7 +23,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -547,7 +548,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,


@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -19,6 +20,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -32,8 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsV0Only)
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -112,7 +114,8 @@ class JambaMambaDecoderLayer(nn.Module):
use_rms_norm=True,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
is_lora_enabled = self.is_lora_enabled
is_lora_enabled = self.is_lora_enabled,
prefix=f"{prefix}.mixer",
)
num_experts = config.layers_num_experts[layer_idx]
@@ -344,7 +347,8 @@ class JambaModel(nn.Module):
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer):
if isinstance(layer,
JambaMambaDecoderLayer) and mamba_cache_params:
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
@@ -442,7 +446,7 @@ class JambaModel(nn.Module):
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only):
IsHybrid):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
".self_attn.": ".",
".A_log": ".A"
@@ -509,14 +513,19 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_layers, *state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
@@ -529,19 +538,22 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_conv - 1,
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
hidden_size = hf_config.hidden_size
return MambaStateShapeCalculator.mamba1_state_shape(
tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.mamba_expand * hidden_size,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=envs.VLLM_USE_V1,
)
temporal_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,


@ -8,20 +8,21 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree, SupportsPP,
SupportsV0Only)
IsAttentionFree, SupportsPP)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -41,7 +42,8 @@ class MambaDecoderLayer(nn.Module):
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False) -> None:
is_lora_enabled: Optional[bool] = False,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
@@ -58,7 +60,8 @@ class MambaModel(nn.Module):
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act,
is_lora_enabled=self.is_lora_enabled)
is_lora_enabled=self.is_lora_enabled,
prefix=f"{prefix}.mixer")
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -107,7 +110,8 @@ class MambaModel(nn.Module):
lambda prefix: MambaDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config,
is_lora_enabled=is_lora_enabled),
is_lora_enabled=is_lora_enabled,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm_f = RMSNorm(config.hidden_size,
@@ -123,7 +127,7 @@
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba_cache_params: Optional[MambaCacheParams] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -140,12 +144,17 @@
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
layer_cache_params = None
if mamba_cache_params is not None:
layer_cache_params = mamba_cache_params.at_layer_idx(
i - self.start_layer)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer))
mamba_cache_params=layer_cache_params)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
@@ -176,8 +185,7 @@ class MambaModel(nn.Module):
return loaded_params
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
SupportsV0Only):
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
@@ -227,20 +235,40 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_layers, *state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
return MambaStateShapeCalculator.mamba1_state_shape(
tp_world_size=parallel_config.tensor_parallel_size,
intermediate_size=hf_config.intermediate_size,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=envs.VLLM_USE_V1)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
@@ -248,19 +276,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
self.config.intermediate_size // world_size,
self.config.conv_kernel - 1,
)
temporal_state_shape = (
self.config.intermediate_size // world_size,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,


@@ -19,7 +19,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -220,7 +221,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.n_groups,


@@ -39,7 +39,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -482,7 +483,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.n_groups,


@@ -32,7 +32,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -869,7 +870,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_ngroups,


@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder
@dataclass
class Mamba1AttentionMetadata:
query_start_loc: torch.Tensor
context_lens_tensor: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states: torch.Tensor
class Mamba1AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba1AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
vllm_config: VllmConfig,
device: torch.device,
layer_names: list[str],
):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.device = device
self.vllm_config = vllm_config
self.layer_names = layer_names
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
query_start_loc.device)
has_initial_states = (context_lens_tensor > 0)
return Mamba1AttentionMetadata(
query_start_loc=query_start_loc,
context_lens_tensor=context_lens_tensor,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
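A toy illustration (plain tensors, not vLLM API) of how the four fields built above describe a batch, here three sequences contributing 4, 1, and 2 new tokens:

import torch

query_start_loc = torch.tensor([0, 4, 5, 7])    # cumulative new-token offsets
context_lens_tensor = torch.tensor([0, 9, 3])   # tokens already in the cache
has_initial_states = context_lens_tensor > 0    # [False, True, True]
state_indices_tensor = torch.tensor([2, 0, 5])  # cache slot per sequence

# Sequence i's tokens occupy [query_start_loc[i] : query_start_loc[i + 1])
# in the flattened batch; sequences with context resume from their slot.
for i in range(3):
    s, e = int(query_start_loc[i]), int(query_start_loc[i + 1])
    print(f"seq {i}: tokens [{s}:{e}), slot {int(state_indices_tensor[i])}, "
          f"resume={bool(has_initial_states[i])}")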


@@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
if mamba_type == "mamba1":
return Mamba1AttentionBackend
if mamba_type == "mamba2":
return Mamba2AttentionBackend
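Selection is driven by each layer's mamba_type property: MambaMixer reports "mamba1" and MambaMixer2 reports "mamba2", so the V1 engine can pick the matching metadata builder per layer, e.g.:

backend_cls = get_mamba_attn_backend("mamba1")  # -> Mamba1AttentionBackend
builder_cls = backend_cls.get_builder_cls()     # -> Mamba1AttentionMetadataBuilder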