[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine (#20660)

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
nopperl 2025-07-11 14:53:31 +09:00 committed by GitHub
parent 31d5c1797f
commit 5d09152ff1
11 changed files with 68 additions and 45 deletions

View File

@@ -1331,6 +1331,17 @@ class ModelConfig:
return sum(t == 1 for t in attn_type_list[start:end])
def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
"""
# used by e.g. Bamba, FalconH1, Granite, PLaMo2
chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
if chunk_size is None:
# used by e.g. Mamba2, NemotronH, Zamba
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
return chunk_size
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
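
For reference, a minimal standalone sketch of the lookup order the new helper implements. The free function `resolve_mamba_chunk_size` and the `SimpleNamespace` objects below are illustrative stand-ins for the method and `hf_text_config`, not real vLLM or HF objects:

from types import SimpleNamespace
from typing import Optional

def resolve_mamba_chunk_size(hf_text_config) -> Optional[int]:
    # Bamba, FalconH1, Granite and PLaMo2-style configs expose `mamba_chunk_size`.
    chunk_size = getattr(hf_text_config, "mamba_chunk_size", None)
    if chunk_size is None:
        # Mamba2, NemotronH and Zamba-style configs expose `chunk_size` instead.
        chunk_size = getattr(hf_text_config, "chunk_size", None)
    return chunk_size

print(resolve_mamba_chunk_size(SimpleNamespace(mamba_chunk_size=256)))  # 256
print(resolve_mamba_chunk_size(SimpleNamespace(chunk_size=128)))        # 128
print(resolve_mamba_chunk_size(SimpleNamespace()))                      # None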

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable
import torch
class MambaBase(ABC):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
"""
# Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache: list[Iterable[torch.Tensor]]
@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
"""
Defines the shape of the state.
For mamba layers this is usually a (conv_state, ssm_state) tuple.
In this case, returns (conv_state_shape, ssm_state_shape).
"""
pass
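
As a rough illustration of the extension point (a hypothetical layer, not part of this commit; all names and sizes below are made up), a custom Mamba-like layer subclasses MambaBase, populates `kv_cache`, and reports its state shapes. The imports and MambaBase definition from the new file above are assumed:

class MyMambaLikeLayer(MambaBase):
    # Hypothetical example layer; real layers also inherit from nn.Module /
    # CustomOp and derive these sizes from the model config.

    def __init__(self,
                 conv_dim: int = 256,
                 conv_kernel_size: int = 4,
                 num_heads: int = 8,
                 head_dim: int = 64,
                 ssm_state_size: int = 16) -> None:
        self._conv_state_shape = (conv_dim, conv_kernel_size - 1)
        self._ssm_state_shape = (num_heads, head_dim, ssm_state_size)
        # One (conv_state, ssm_state) entry per v0 PP virtual engine;
        # the v1 code path only ever uses index 0.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        return (self._conv_state_shape, self._ssm_state_shape)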

View File

@@ -17,6 +17,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def mamba_v2_sharded_weight_loader(
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
class MambaMixer2(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
@@ -428,10 +428,7 @@ class MambaMixer2(CustomOp):
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert chunk_size != -1, "chunk_size must be set for v1"
# NOTE: chunk_size may be -1 for models without v1 support
self.chunk_size = chunk_size
self.prefix = prefix
def forward_native(
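
For context on the `kv_cache` comment above, a hedged standalone sketch of the container layout (not vLLM code): the outer list is indexed by the v0 pipeline-parallel virtual-engine id, each entry is a (conv_state, ssm_state) pair, and the v1 code path only ever touches index 0. The empty tensors are placeholders, presumably replaced when the engine initializes the actual state buffers:

import torch

num_virtual_engines = 1  # v1 always uses a single entry
kv_cache: list[tuple[torch.Tensor, torch.Tensor]] = [
    (torch.tensor([]), torch.tensor([]))  # (conv_state, ssm_state) placeholders
    for _ in range(num_virtual_engines)
]
conv_state, ssm_state = kv_cache[0]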

View File

@@ -99,8 +99,7 @@ class BambaMixerDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
prefix=f"{prefix}.mixer")
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,

View File

@@ -109,7 +109,6 @@ class FalconH1SSMDecoderLayer(nn.Module):
quant_config=quant_config,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size,
)
# n_groups is overridden later by `MambaMixer2`
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state

View File

@@ -69,8 +69,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
prefix=f"{prefix}.mixer")
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:

View File

@@ -62,8 +62,7 @@ class Mamba2DecoderLayer(nn.Module):
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
prefix=f"{prefix}.mixer")
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

View File

@@ -154,7 +154,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
activation=config.mamba_hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@@ -501,8 +501,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
prefix=f"{prefix}.mixer")
# Input normalization
self.input_layernorm = RMSNorm(config.hidden_size,

View File

@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
@@ -19,15 +18,6 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
chunk_sizes = set(layer.chunk_size for layer in layers.values())
assert len(
chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
return chunk_sizes.pop()
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
@@ -102,7 +92,10 @@ class Mamba2AttentionMetadataBuilder(
self.runner = runner
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
)
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
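
A hedged, standalone sketch of the new chunk-size plumbing in the builder, with `SimpleNamespace` standing in for the real runner and config objects (the real builder receives a GPUModelRunner whose model config implements `get_mamba_chunk_size()`):

from types import SimpleNamespace

model_config = SimpleNamespace(get_mamba_chunk_size=lambda: 256)
runner = SimpleNamespace(vllm_config=SimpleNamespace(model_config=model_config))

chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size()
assert chunk_size is not None, (
    "chunk_size needs to be set in the model config for Mamba2 models")
print(chunk_size)  # 256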

View File

@@ -30,7 +30,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2623,8 +2623,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -2655,7 +2654,7 @@ def _maybe_pad_mamba_page_size(
def _maybe_pad_mamba_page_size(
self,
attn_layers: dict[str, Attention],
mamba_layers: dict[str, MambaMixer2],
mamba_layers: dict[str, MambaBase],
kv_cache_spec: dict[str, KVCacheSpec],
max_model_len: int,
block_size: int,
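
Finally, a rough standalone illustration of why switching the lookup type from MambaMixer2 to MambaBase is enough to pick up other Mamba-like layers: discovery by layer type amounts to an isinstance check against the requested base class, so any layer inheriting from MambaBase is returned by the same query. The helper below is an illustrative stand-in, not the real `get_layers_from_vllm_config`:

def layers_of_type(modules: dict[str, object],
                   layer_type: type) -> dict[str, object]:
    # Stand-in for type-based layer discovery: keep only modules that are
    # instances of the requested base class.
    return {
        name: module
        for name, module in modules.items() if isinstance(module, layer_type)
    }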