mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:34:57 +08:00
Signed-off-by: Yury Tokpanov <yury@zyphra.com> Signed-off-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
545 lines
22 KiB
Python
545 lines
22 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
|
from vllm.attention.backends.placeholder_attn import (
|
|
PlaceholderAttentionMetadata)
|
|
from vllm.attention.backends.xformers import XFormersMetadata
|
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_gather,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|
causal_conv1d_fn, causal_conv1d_update)
|
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
|
selective_state_update)
|
|
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
|
mamba_chunk_scan_combined)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
|
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
# Added by the IBM Team, 2024
|
|
|
|
|
|
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
|
|
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
    """Gated, grouped RMSNorm used by the Mamba2 mixer.

    The input is first multiplied by ``silu(gate)`` and then RMS-normalised
    over groups of ``full_hidden_size // full_n_groups`` channels. The scale
    parameter holds ``full_hidden_size // tp_size`` entries per rank and is
    loaded through ``sharded_weight_loader(0)``.
    """

    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
        """Record the tensor-parallel layout and create the scale weight.

        Args:
            full_hidden_size: hidden size before tensor-parallel sharding.
            full_n_groups: number of normalisation groups before sharding.
            eps: constant added to the variance before the reciprocal sqrt.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Sizes derived from the unsharded hidden dimension.
        self.full_hidden_size = full_hidden_size
        channels_per_group = full_hidden_size // full_n_groups
        self.group_size = channels_per_group
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // channels_per_group

        self.variance_epsilon = eps

        scale = nn.Parameter(torch.ones(self.per_rank_hidden_size))
        self.weight = scale
        set_weight_attrs(scale, {"weight_loader": sharded_weight_loader(0)})

        assert self.full_hidden_size % self.tp_size == 0, \
            "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Apply the gate and the grouped RMSNorm with plain torch ops.

        Three tensor-parallel situations are handled:

        1. ``n_groups == 1``: the reduction dimension itself is sharded, so
           each rank computes a local sum of squares and the sums are
           combined with an all-reduce (skipped when ``tp_size == 1``).
        2. ``tp_size`` divides ``n_groups``: every rank normalises only its
           local group(s); no collective is needed.
        3. Otherwise: the input is all-gathered, the norm is computed
           redundantly on every rank, and the local slice is cut back out.

        The result is cast back to the dtype of ``x`` before it is scaled
        by ``self.weight``.
        """
        out_dtype = x.dtype
        eps = self.variance_epsilon
        gated = x * nn.functional.silu(gate.to(torch.float32))

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Local sum of squares, then a global sum across ranks.
                partial_sq_sum = gated.pow(2).sum(dim=-1, keepdim=True)
                total_sq_sum = tensor_model_parallel_all_reduce(
                    partial_sq_sum)
                # Divide by the number of elements across all ranks.
                element_count = self.tp_size * gated.shape[-1]
                variance = (total_sq_sum / element_count)
            else:
                variance = gated.pow(2).mean(-1, keepdim=True)
            normed = gated * torch.rsqrt(variance + eps)
            return self.weight * normed.to(out_dtype)

        # More than one group.
        needs_gather: bool = self.n_groups % self.tp_size != 0
        if needs_gather:
            # General case: gather the full hidden dim and redo the
            # normalisation on every rank.
            gated = tensor_model_parallel_all_gather(gated, -1)

        *leading_dims, width = gated.shape
        groups_here = width // self.group_size
        grouped = gated.view(*leading_dims, groups_here, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + eps)
        normed = grouped.view(*leading_dims, width)

        if needs_gather:
            # Keep only the slice that belongs to this rank.
            lo = self.per_rank_hidden_size * self.tp_rank
            hi = lo + self.per_rank_hidden_size
            normed = normed[..., lo:hi]

        return self.weight * normed.to(out_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Use the fused ``rms_norm`` custom op when it applies.

        The fused op is used only for a single rank with a single group;
        every other configuration is delegated to ``forward_native``.
        """
        fused_path_ok = self.tp_size <= 1 and self.n_groups == 1
        if not fused_path_ok:
            return self.forward_native(x, gate)

        from vllm import _custom_ops as ops

        # The gate is cast to float32 before silu; the product is cast back
        # to the dtype of x before it is handed to the kernel.
        gated = x * nn.functional.silu(gate.to(torch.float32))
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            gated.to(x.dtype),
            self.weight.data,
            self.variance_epsilon,
        )
        return out
|
|
|
|
|
|
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Return how many groups must be added so groups can follow head shards.

    When ``tp_size`` divides ``ngroups`` nothing has to be replicated and
    the result is 0. Otherwise the result is ``tp_size - ngroups``; for
    ``ngroups == 1`` that is the number of additional copies needed so that
    every rank gets one group.
    """
    divides_evenly = ngroups % tp_size == 0
    return 0 if divides_evenly else tp_size - ngroups
|
|
|
|
|
|
def mamba_v2_sharded_weight_loader(
    shard_spec: List[Tuple[int, int, bool]],
    tp_size: int,
    tp_rank: int,
) -> LoaderFunction:
    """Create a weight loader for mamba v2.

    The loader shards a weight that is a concatenation of several segments
    along dim 0 (e.g. x, B, C), so that each segment is sharded on its own
    and the result can still be split into its parts. It also makes sure
    the groups belonging to a head shard are placed together with it.

    Args:
        shard_spec: one ``(full_dim, extra, duplicate_groups)`` entry per
            concatenated segment, in order.

            - ``full_dim``: size of the segment in the model before TP
              sharding (including any replicated groups).
            - ``extra``: how many of those ``full_dim`` rows do not exist
              in the checkpoint because they come from replication.
            - ``duplicate_groups``: when truthy, every TP rank reads shard
              0 of the segment instead of its own shard.
        tp_size: tensor-parallel world size.
        tp_rank: rank of this process within the tensor-parallel group.

    Returns:
        A ``loader(param, loaded_weight)`` callable that copies this rank's
        rows of ``loaded_weight`` into ``param.data`` in place.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # - track boundary of (sharded) param, and loaded_weight, respectively
        boundary, loaded_boundary = 0, 0

        # - iterate over the shard specs
        for full_dim, extra, duplicate_groups in shard_spec:
            # - size of the loaded shard
            shard_size = full_dim // tp_size

            # - compute the rank into the loaded shard.
            # - if there is replication, different TP shards will
            #   take from the same rank.
            # NOTE: currently we only support duplication
            # in the case where num_groups == 1
            rank = 0 if duplicate_groups else tp_rank

            # - leftmost boundary index into loaded weight.
            loaded_skip = rank * shard_size
            loaded_start_idx = loaded_boundary + loaded_skip

            # - take these many dims from the loaded weight; the
            #   checkpoint only holds (full_dim - extra) rows of the segment.
            take = min(shard_size, full_dim - extra - loaded_skip)

            # - always shard on dim 0
            # - the ignore is for a mundane mypy error as it does not
            #   seem to handle slices well.
            # https://github.com/python/mypy/issues/2410
            param.data[
                boundary:(boundary + take),  # type: ignore[misc]
                ...] = loaded_weight[loaded_start_idx:(  # type: ignore[misc]
                    loaded_start_idx + take)]  # type: ignore[misc]

            # - move indexing boundaries: the param advances by a whole
            #   shard, the checkpoint by the rows it actually contains.
            boundary += shard_size
            loaded_boundary += (full_dim - extra)

    return loader
|
|
|
|
|
|
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
|
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)
    """

    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 n_groups: int = 1,
                 num_heads: int = 128,
                 head_dim: int = 64,
                 rms_norm_eps: float = 1e-5,
                 activation="silu",
                 chunk_size: int = 256,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the projections, conv weights, SSM parameters and norm.

        Args:
            hidden_size: input size of ``in_proj`` and output size of
                ``out_proj``.
            ssm_state_size: state size; B and C each span
                ``n_groups * ssm_state_size`` channels.
            conv_kernel_size: kernel width of the conv weights.
            intermediate_size: inner width of the mixer (before sharding).
            use_conv_bias: whether the conv layer has a bias.
            use_bias: whether ``in_proj`` / ``out_proj`` have a bias.
            n_groups: number of B/C groups in the checkpoint.
            num_heads: number of heads (before sharding).
            head_dim: per-head width.
            rms_norm_eps: epsilon of the gated RMSNorm.
            activation: activation name passed to the conv kernels.
            chunk_size: chunk size passed to the chunked scan kernel.
            quant_config: quantization config for the linear projections.

        Raises:
            AssertionError: if ``tp_size`` does not divide ``num_heads``, or
                if it does not divide ``n_groups`` while ``n_groups != 1``.
        """
        super().__init__()

        # For TP, the sharding plan is as follows:
        # - for the conv modules, since
        #   conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
        #   we shard intermediate_size and n_groups
        # - since intermediate_size = n_heads * head_dim, sharding on
        #   intermediate_size is achieved by sharding on n_heads.
        # - IF, world_size divides groups, then sharding
        #   (n_groups / world_size, n_heads / world_size)
        #   also maintains the invariant n_heads % n_groups == 0
        # - HOWEVER IF, world_size DOES NOT divide groups, then we need
        #   to allocate extra space in the shard, such that groups
        #   may be replicated to follow the head shard.
        # - NOTE: currently for the world size DOES NOT divide groups
        #   case, we only support the case when n_groups == 1
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        assert num_heads % self.tp_size == 0, \
            "Tensor parallel world size must divide num heads."

        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
            (
                "If tensor parallel world size does not divide num_groups, "
                "then num_groups must equal 1."
            )

        self.ssm_state_size = ssm_state_size
        self.activation = activation

        self.chunk_size = chunk_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads

        self.n_groups = n_groups
        if n_groups % self.tp_size != 0:
            # - for TP we shard conv_dim by sharding on n_groups,
            # - but if tp_size does not divide n_groups, we need to
            #   extend some extra groups
            self.n_groups = n_groups + extra_groups_for_head_shards(
                n_groups, self.tp_size)

        self.conv_dim = (intermediate_size +
                         2 * self.n_groups * ssm_state_size)
        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=self.conv_dim,
            bias=use_conv_bias,
            quant_config=None,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = ColumnParallelLinear(input_size=hidden_size,
                                            output_size=intermediate_size +
                                            self.conv_dim + self.num_heads,
                                            bias=use_bias,
                                            quant_config=quant_config)

        # - because in_proj is a concatenation of 3 weights, we
        #   need to interleave them before sharding
        # - use the custom weight loader mamba_v2_sharded_weight_loader
        #   for conv1d.bias, conv1d.weight and in_proj.weight
        # - need to set these settings, to assign the groups to the head shards
        group_shard_settings = (
            self.n_groups * self.ssm_state_size,  # expected model size
            (self.n_groups - n_groups) *
            self.ssm_state_size,  # extra dims assigned
            n_groups == 1,  # if there was only one group
        )
        intermediate_settings = (intermediate_size, 0, False)
        head_settings = (self.num_heads, 0, False)

        # - the weight already has a "weight_loader" attribute
        #   which set_weight_attrs will raise if we do not
        #   delete before trying to override it
        # - ditto for the other two weights below
        # - NOTE(review): the bias is presumably None when use_conv_bias is
        #   False (confirm against ColumnParallelLinear); skip it in that
        #   case instead of failing in delattr.
        if getattr(self.conv1d, "bias", None) is not None:
            delattr(self.conv1d.bias, "weight_loader")
            set_weight_attrs(
                self.conv1d.bias, {
                    "weight_loader":
                    mamba_v2_sharded_weight_loader(
                        [
                            intermediate_settings,
                            group_shard_settings,
                            group_shard_settings,
                        ],
                        self.tp_size,
                        tp_rank,
                    )
                })

        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight, {
                "weight_loader":
                mamba_v2_sharded_weight_loader([
                    intermediate_settings,
                    group_shard_settings,
                    group_shard_settings,
                ], self.tp_size, tp_rank)
            })

        delattr(self.in_proj.weight, "weight_loader")
        set_weight_attrs(
            self.in_proj.weight,
            {
                "weight_loader":
                mamba_v2_sharded_weight_loader(
                    [
                        intermediate_settings,  # for gate
                        intermediate_settings,
                        group_shard_settings,
                        group_shard_settings,
                        head_settings,  # for dt
                    ],
                    self.tp_size,
                    tp_rank)
            })

        # - these are TPed by heads to reduce the size of the
        #   temporal shape
        self.A = nn.Parameter(
            torch.empty(
                divide(num_heads, self.tp_size),
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        # - the loaded value of A is transformed with -exp(x) in float32
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
        set_weight_attrs(self.dt_bias,
                         {"weight_loader": sharded_weight_loader(0)})

        self.out_proj = RowParallelLinear(intermediate_size,
                                          hidden_size,
                                          bias=use_bias,
                                          input_is_parallel=True,
                                          quant_config=quant_config)

        # - the norm is built with the checkpoint's n_groups, not the
        #   (possibly extended) self.n_groups
        self.norm = Mixer2RMSNormGated(intermediate_size,
                                       n_groups,
                                       eps=rms_norm_eps)

    def forward_native(self, hidden_states: torch.Tensor,
                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
        """Placeholder: no native implementation; returns None."""
        pass

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        sequence_idx: Optional[torch.Tensor] = None,
    ):
        """Run the mixer on a batch of tokens.

        Args:
            hidden_states: 2-D input whose first dimension is the number of
                tokens in the batch.
            mamba_cache_params: holds ``conv_state``, ``ssm_state`` and
                ``state_indices_tensor``; both states are updated in place.
            sequence_idx: forwarded as ``seq_idx`` to the chunked scan in
                the prefill path.

        Returns:
            The first output of ``out_proj`` applied to the gated,
            normalised scan result.
        """
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size

        # detect if there are prefills
        has_prefill = attn_metadata.num_prefills > 0

        # - also need flags to indicate if there are initial states
        # - only these metadata types are inspected for context lengths
        has_initial_states = None
        if (isinstance(attn_metadata,
                       (FlashAttentionMetadata, XFormersMetadata,
                        PlaceholderAttentionMetadata))
                and attn_metadata.context_lens_tensor is not None):
            has_initial_states = attn_metadata.context_lens_tensor > 0

        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)
        gate, hidden_states_B_C, dt = torch.split(
            projected_states,
            [
                self.intermediate_size // self.tp_size,
                self.conv_dim // self.tp_size,
                self.num_heads // self.tp_size,
            ],
            dim=-1,
        )

        # 2. Convolution sequence transformation
        # - drop the singleton dim that __init__ added with unsqueeze(1)
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if has_prefill:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|

            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "mamba_cache_params.state_indices_tensor"
            hidden_states_B_C = causal_conv1d_fn(
                hidden_states_B_C.transpose(0, 1),
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=has_initial_states,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc).transpose(
                    0, 1)[:seq_len]

            # TODO: Why is this needed?
            hidden_states_B_C = hidden_states_B_C.contiguous()
        else:
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)

        # - get hidden_states, B and C after depthwise convolution.
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [
                self.intermediate_size // self.tp_size,
                groups_time_state_size // self.tp_size,
                groups_time_state_size // self.tp_size,
            ],
            dim=-1,
        )

        # 3. State Space Model sequence transformation
        if has_prefill:

            initial_states = None
            # - Tensor.any() reduces in one call instead of iterating the
            #   tensor element by element through the builtin any()
            if has_initial_states is not None and has_initial_states.any():
                state_indices = mamba_cache_params.state_indices_tensor
                # - slots of sequences with context length 0 must start
                #   from a zero state; clear them in one indexed assignment
                mamba_cache_params.ssm_state[
                    state_indices[~has_initial_states]] = 0
                initial_states = mamba_cache_params.ssm_state[state_indices]

            scan_output, varlen_state = mamba_chunk_scan_combined(
                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
                                   self.head_dim),
                dt.unsqueeze(0),
                self.A,
                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                seq_idx=sequence_idx,
                cu_seqlens=attn_metadata.query_start_loc,
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
            )

            # update ssm states
            # - varlen state is a (batch, nheads, headdim, dstate) tensor
            # - copy_ is kept (rather than an indexed assignment) because it
            #   also converts dtype if the cache and the scan output differ
            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])

            # - reshape
            hidden_states = scan_output.view(seq_len, -1)
        else:

            n_groups = self.n_groups // self.tp_size
            A = self.A[:, None, ...][:, :, None].expand(
                -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(-1, n_groups, B.shape[1] // n_groups)
            C = C.view(-1, n_groups, C.shape[1] // n_groups)
            hidden_states_reshaped = hidden_states.view(
                -1, self.num_heads // self.tp_size, self.head_dim)

            # - the hidden is reshaped into number of current batches
            # - in this case there is no more prefill, so the batches gen
            #   1 token at a time
            # - thus hidden will be (bs, num_heads, head_dim)
            # - mamba_cache_params.ssm_state's slots will be selected
            #   using "mamba_cache_params.state_indices_tensor", just as
            #   above in the prefill case

            hidden_states = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor,
            )
            hidden_states = hidden_states.view(
                -1, (self.num_heads // self.tp_size) * self.head_dim)

        # 4. gated MLP
        hidden_states = self.norm(hidden_states, gate)

        # 5. Final linear projection
        out, _ = self.out_proj(hidden_states)
        return out
|