[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine (#20660)
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
parent 31d5c1797f
commit 5d09152ff1
@@ -1331,6 +1331,17 @@ class ModelConfig:
         return sum(t == 1 for t in attn_type_list[start:end])
 
+    def get_mamba_chunk_size(self) -> Optional[int]:
+        """
+        Returns the mamba chunk size if it exists
+        """
+        # used by e.g. Bamba, FalconH1, Granite, PLaMo2
+        chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
+        if chunk_size is None:
+            # used by e.g. Mamba2, NemotronH, Zamba
+            chunk_size = getattr(self.hf_text_config, "chunk_size", None)
+        return chunk_size
+
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
         Get the multimodal configuration of the model.
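For reference, the fallback order of the new accessor can be exercised in isolation. The sketch below is illustrative only and is not part of the commit; types.SimpleNamespace stands in for hf_text_config:

# Illustrative sketch of the fallback in ModelConfig.get_mamba_chunk_size.
# SimpleNamespace stands in for the HF text config referenced in the diff above.
from types import SimpleNamespace
from typing import Optional


def get_mamba_chunk_size(hf_text_config) -> Optional[int]:
    # Bamba, FalconH1, Granite and PLaMo2 expose `mamba_chunk_size`
    chunk_size = getattr(hf_text_config, "mamba_chunk_size", None)
    if chunk_size is None:
        # Mamba2, NemotronH and Zamba expose `chunk_size`
        chunk_size = getattr(hf_text_config, "chunk_size", None)
    return chunk_size


print(get_mamba_chunk_size(SimpleNamespace(mamba_chunk_size=256)))  # 256
print(get_mamba_chunk_size(SimpleNamespace(chunk_size=128)))        # 128
print(get_mamba_chunk_size(SimpleNamespace()))                      # None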
vllm/model_executor/layers/mamba/abstract.py  (new file, 29 lines added)
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+
+import torch
+
+
+class MambaBase(ABC):
+    """
+    Base class for Mamba-like layers which support the v1 engine.
+    Inherit from this class if you implement a custom layer.
+    """
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[Iterable[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the state.
+        For mamba layers this is usually a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
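As a rough sketch of the contract this interface defines: a custom layer that should be recognized by the v1 engine subclasses MambaBase, seeds its kv_cache placeholder, and reports its state shapes. The class name and every size below are illustrative placeholders, not values used anywhere in vLLM:

# Hypothetical custom layer demonstrating the MambaBase contract above.
# All concrete sizes are made-up placeholders for illustration.
from collections.abc import Iterable

import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class MyCustomMixer(MambaBase):

    def __init__(self) -> None:
        # Empty placeholder state, mirroring how MambaMixer2 seeds its v1
        # kv_cache in the shape reported by get_state_shape.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        conv_state_shape = (1024, 3)     # e.g. (conv_dim, kernel_size - 1)
        ssm_state_shape = (32, 64, 128)  # e.g. (num_heads, head_dim, state_size)
        return conv_state_shape, ssm_state_shape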
@@ -17,6 +17,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def mamba_v2_sharded_weight_loader(
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
     """
 
     def __init__(
         self,
         hidden_size: int,
         ssm_state_size: int,
         conv_kernel_size: int,
         intermediate_size: int,
         use_conv_bias: bool,
         use_bias: bool,
         n_groups: int = 1,
         num_heads: int = 128,
         head_dim: int = 64,
         rms_norm_eps: float = 1e-5,
         activation: str = "silu",
         use_rms_norm: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        chunk_size: int = -1,  # the chunk size used by v1
     ):
         super().__init__()
 
@@ -428,10 +428,7 @@ class MambaMixer2(CustomOp):
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
 
-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
@@ -99,8 +99,7 @@ class BambaMixerDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -109,7 +109,6 @@ class FalconH1SSMDecoderLayer(nn.Module):
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -69,8 +69,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.block_sparse_moe = None
         if getattr(config, "num_local_experts", 0) > 0:
@@ -62,8 +62,7 @@ class Mamba2DecoderLayer(nn.Module):
             rms_norm_eps=config.layer_norm_epsilon,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -154,7 +154,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -501,8 +501,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import MambaSpec
@@ -19,15 +18,6 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
-def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
-    chunk_sizes = set(layer.chunk_size for layer in layers.values())
-    assert len(
-        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
-    return chunk_sizes.pop()
-
-
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):
@@ -102,7 +92,10 @@ class Mamba2AttentionMetadataBuilder(
         self.runner = runner
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
+        )
+        assert self.chunk_size is not None, (
+            "chunk_size needs to be set in the model config for Mamba2 models")
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -30,7 +30,7 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2623,8 +2623,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     raise ValueError(
                         f"Unknown attention type: {attn_module.attn_type}")
 
-        mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   MambaMixer2)
+        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
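With the lookup widened from MambaMixer2 to MambaBase, any layer whose type subclasses MambaBase is now collected here. A simplified, self-contained stand-in for that base-class filtering (not the actual get_layers_from_vllm_config implementation; all names below are placeholders) looks like:

# Minimal illustrative stand-in for filtering layers by base class.
class Base:
    pass


class MixerA(Base):  # stands in for MambaMixer2
    pass


class MixerB(Base):  # stands in for another MambaBase subclass
    pass


def layers_of_type(layers: dict[str, object], base: type) -> dict[str, object]:
    return {name: layer for name, layer in layers.items()
            if isinstance(layer, base)}


layers = {"model.layers.0.mixer": MixerA(), "model.layers.1.mixer": MixerB()}
assert len(layers_of_type(layers, Base)) == 2  # both subclass instances found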
@@ -2655,7 +2654,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _maybe_pad_mamba_page_size(
             self,
             attn_layers: dict[str, Attention],
-            mamba_layers: dict[str, MambaMixer2],
+            mamba_layers: dict[str, MambaBase],
             kv_cache_spec: dict[str, KVCacheSpec],
             max_model_len: int,
             block_size: int,