mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 11:48:29 +08:00
428 lines
18 KiB
Python
428 lines
18 KiB
Python
# coding=utf-8
|
|
"""PyTorch MAMBA model."""
|
|
from typing import Iterable, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import MambaConfig
|
|
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|
causal_conv1d_fn, causal_conv1d_update)
|
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
|
selective_scan_fn, selective_state_update)
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig)
|
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
|
from vllm.model_executor.models.interfaces import (HasInnerState,
|
|
IsAttentionFree)
|
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
MambaCacheParams)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
|
_get_graph_batch_size)
|
|
|
|
# Tensor-pair alias (presumably an attention (key, value) cache entry, by
# name). In this file it only annotates the `kv_caches` argument of
# MambaForCausalLM.forward, which that forward never reads; Mamba's
# recurrent state lives in MambaCacheManager instead.
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
|
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)

    Tensor parallelism: the intermediate (channel) dimension is split
    across ranks. A and D are allocated with ``intermediate_size // tp_size``
    rows and loaded through a dim-0 ``sharded_weight_loader``; conv1d,
    in_proj and dt_proj are column-parallel and x_proj / out_proj are
    row-parallel over the same channel dimension.

    When ``config.model_type == "falcon_mamba"`` the time step, B and C
    are additionally RMS-normalized before the scan.

    [1] Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective
        State Spaces", 2023.
    """

    def __init__(self, config: MambaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        # The causal conv1d is held as a column-parallel "linear" layer
        # (one conv_kernel-tap filter per intermediate channel) so its
        # channels shard like in_proj's output. The layer itself is never
        # called: forward() hands its weight and bias to the causal_conv1d
        # kernels.
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=config.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projects hidden_size -> two intermediate_size halves (the SSM
        # input and the gate), split apart in forward().
        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=config.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        tp_size = get_tensor_model_parallel_world_size()
        # A is kept in float32. Its weight loader (below) converts the
        # checkpoint value x into -exp(x); MambaForCausalLM.load_weights
        # routes checkpoint "A_log" tensors to this parameter.
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        # Per-channel D, initialized to ones and overwritten at load time.
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        # Both A and D load the local dim-0 shard for this TP rank.
        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})

        # input_is_parallel=True: consumes the TP-sharded scan output.
        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=config.use_bias,
            input_is_parallel=True,
        )
        # Activation name forwarded to the causal_conv1d kernels.
        self.activation = config.hidden_act
        if self.is_falcon_mamba:
            self.dt_layernorm = RMSNorm(self.time_step_rank,
                                        eps=config.mixer_rms_eps)
            self.b_layernorm = RMSNorm(self.ssm_state_size,
                                       eps=config.mixer_rms_eps)
            self.c_layernorm = RMSNorm(self.ssm_state_size,
                                       eps=config.mixer_rms_eps)

    def forward(self, hidden_states: torch.Tensor,
                attn_metadata: AttentionMetadata,
                mamba_cache_params: MambaCacheParams) -> torch.Tensor:
        """Run the Mamba mixer over the current batch of tokens.

        Two execution paths:

        * If ``attn_metadata.query_start_loc`` and
          ``attn_metadata.context_lens_tensor`` are both set, the
          variable-length kernels (``causal_conv1d_fn``,
          ``selective_scan_fn``) run over the packed sequences. Sequences
          with ``context_lens_tensor > 0`` start from their cached state.
        * Otherwise the single-step update kernels
          (``causal_conv1d_update``, ``selective_state_update``) are used.

        In both paths ``mamba_cache_params.conv_state`` / ``ssm_state`` are
        handed to the kernels together with ``state_indices_tensor``. They
        are presumably read and updated in place; confirm against the kernel
        implementations.

        NOTE(review): assumes ``hidden_states`` is
        ``(num_tokens, hidden_size)``; TODO confirm against the caller.
        """

        # 1. Gated MLP's linear projection
        # Transposed to channel-first, then split along the channel axis
        # into the SSM input and the gate.
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        # Drop the unsqueezed middle dim added in __init__:
        # (channels, 1, conv_kernel) -> (channels, conv_kernel).
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if attn_metadata.query_start_loc is not None \
            and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            # The update kernel takes token-major input; transpose in and
            # back out so both paths leave channel-first hidden_states.
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        # Note that Jamba and FalconMamba normalize B, C, and time_step here
        # but Mamba doesn't.
        if self.is_falcon_mamba:
            time_step = self.dt_layernorm(time_step.contiguous())
            B = self.b_layernorm(B.contiguous())
            C = self.c_layernorm(C.contiguous())

        # 3.b. discretization: project time_step without its bias
        # (skip_bias_add=True); the bias is applied inside the scan kernel
        # together with softplus (delta_softplus / dt_softplus=True).
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)

        if attn_metadata.query_start_loc is not None \
            and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            # NOTE(review): this path passes self.D without .float(), unlike
            # the varlen path above; confirm the kernel accepts D's dtype.
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        # Back to token-major layout before the row-parallel projection.
        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
                                                                     -1))[0]
        return contextualized_states
|
|
|
|
|
|
class MambaDecoderLayer(nn.Module):
    """A single pre-norm Mamba block: ``RMSNorm`` followed by ``MambaMixer``.

    The residual stream is carried by the caller. Each call returns
    ``(mixer_output, residual)``, and the next layer passes that residual
    back in. ``cache_config`` and ``quant_config`` are accepted but not
    used by this layer.
    """

    def __init__(self,
                 config: MambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # Plain (non-module) attributes.
        self.config = config
        self.layer_idx = layer_idx
        self.is_falcon_mamba = config.model_type == "falcon_mamba"
        # Submodules: registered mixer-first, then norm.
        self.mixer = MambaMixer(config, layer_idx)
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Normalize, mix, and return ``(mixer_output, residual)``.

        Extra keyword arguments (e.g. ``positions``) are accepted and
        ignored.
        """
        if residual is not None:
            # Two-argument RMSNorm returns (normed, updated_residual);
            # presumably it fuses the residual add.
            normed, residual = self.norm(hidden_states, residual)
        else:
            # First layer: the un-normalized input seeds the residual stream.
            residual = hidden_states
            normed = self.norm(hidden_states)

        mixed = self.mixer(normed, attn_metadata, mamba_cache_params)
        return mixed, residual
|
|
|
|
|
|
class MambaModel(nn.Module):
    """Mamba backbone: embedding -> stack of MambaDecoderLayers -> RMSNorm.

    Args:
        config: HF ``MambaConfig`` (also used for FalconMamba checkpoints).
        quant_config: forwarded to each decoder layer.
        cache_config: forwarded to each decoder layer.
        lora_config: when set, the embedding table is enlarged by
            ``lora_extra_vocab_size * (max_loras or 1)`` rows.
    """

    def __init__(
        self,
        config: MambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra LoRA vocabulary, one block per adapter slot.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.layers = nn.ModuleList([
            MambaDecoderLayer(config,
                              layer_idx=i,
                              cache_config=cache_config,
                              quant_config=quant_config)
            for i in range(config.num_hidden_layers)
        ])
        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply the final norm.

        Each layer ``i`` receives the per-layer view
        ``mamba_cache_params.at_layer_idx(i)`` of the recurrent state.
        ``positions`` is passed through to the layers, which ignore it.
        """
        hidden_states = self.embeddings(input_ids)
        residual = None

        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(i))

        if residual is None:
            # No decoder layer ran (num_hidden_layers == 0). The one-argument
            # RMSNorm returns a single tensor, so it must not be unpacked.
            return self.norm_f(hidden_states)

        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba / FalconMamba causal LM: MambaModel + LM head + sampler.

    The recurrent (conv + SSM) state is held in a ``MambaCacheManager``
    that is created lazily on the first ``forward`` call. Prefix caching is
    not supported and is rejected at construction time.
    """

    def __init__(
        self,
        config: MambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        # cache_config defaults to None; guard the attribute access so that
        # constructing without a cache config no longer raises
        # AttributeError. Enabling prefix caching still fails the assertion.
        prefix_caching = (cache_config is not None
                          and cache_config.enable_prefix_caching)
        assert not prefix_caching, "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.backbone = MambaModel(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if config.tie_word_embeddings:
            # Share the input embedding module as the output projection.
            self.lm_head = self.backbone.embeddings
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
            )

        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        """Run the backbone and return its final hidden states.

        ``kv_caches`` and ``intermediate_tensors`` are accepted but not used.
        ``kwargs`` are forwarded to
        ``MambaCacheManager.current_run_tensors``.
        """
        if self.mamba_cache is None:
            # Size the state cache from the scheduler's max_num_seqs (via
            # _get_graph_batch_size); without a scheduler config, fall back
            # to the largest CUDA-graph capture size + 2.
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, self.config.num_hidden_layers,
                max_batch_size, *self._get_mamba_cache_shape())

        (
            mamba_cache_tensors,
            state_indices_tensor,
        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
                                                 **kwargs)

        # Presumably [0] is the conv state and [1] the SSM state, matching
        # the order of _get_mamba_cache_shape(); confirm in mamba_cache.
        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
                                              mamba_cache_tensors[1],
                                              state_indices_tensor)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the cache manager (must already exist)."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the cache manager (must already exist)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return per-sequence ``(conv_state_shape, temporal_state_shape)``.

        Both shapes cover this TP rank's slice of the intermediate channels.
        The conv state keeps the last ``conv_kernel - 1`` inputs per channel.
        """
        world_size = get_tensor_model_parallel_world_size()
        conv_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.conv_kernel - 1,
        )
        temporal_state_shape = (
            self.config.intermediate_size // world_size,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs from a checkpoint.

        Checkpoint ``A_log`` tensors are routed to the ``A`` parameters,
        whose custom weight loader performs the conversion. Bias tensors
        with no matching parameter are skipped. Any other unknown name
        raises ``KeyError``.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
|