[Mamba] Support TP>1 with quantization for mamba2 mixer in case n_groups % tp_size == 0 (#24593)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 68dbde5dbb
commit 27fcfe7bcf
@@ -19,6 +19,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
@@ -261,12 +262,14 @@ class MambaMixer2(MambaBase, CustomOp):
         ), "Tensor parallel world size must divide num heads."
 
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
-            "If tensor parallel world size does not divide num_heads, "
+            "If tensor parallel world size does not divide num_groups, "
             "then num_groups must equal 1.")
 
-        assert (
-            self.tp_size == 1 or quant_config is None
-        ), "Tensor parallel currently not supported for quantized models."
+        assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
+            quant_config is None, (
+                "Tensor parallel currently supported for quantized models only "
+                "if tensor parallel world size divides num groups."
+            )
 
         self.ssm_state_size = ssm_state_size
         self.conv_kernel_size = conv_kernel_size
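For reference, the reworked assertion above encodes a small support matrix. The helper below is only an illustrative sketch (the function name is hypothetical; n_groups, tp_size and quant_config mirror the variables in the diff):

def tp_with_quantization_supported(n_groups: int, tp_size: int,
                                   quantized: bool) -> bool:
    # Unquantized models keep the old behavior: TP is allowed, with groups
    # either sharded (n_groups % tp_size == 0) or duplicated (n_groups == 1).
    if not quantized:
        return True
    # Quantized models previously required tp_size == 1; after this change
    # they are also supported whenever tp_size divides n_groups.
    return tp_size == 1 or n_groups % tp_size == 0

assert tp_with_quantization_supported(n_groups=8, tp_size=4, quantized=True)
assert not tp_with_quantization_supported(n_groups=1, tp_size=4, quantized=True)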
@@ -285,94 +288,84 @@ class MambaMixer2(MambaBase, CustomOp):
             n_groups, self.tp_size)
         self.n_groups = n_groups + groups
 
-        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
-        self.conv1d = ColumnParallelLinear(
-            input_size=conv_kernel_size,
-            output_size=self.conv_dim,
-            bias=use_conv_bias,
-            quant_config=None,
-            prefix=f"{prefix}.conv1d",
-        )
-        # unsqueeze to fit conv1d weights shape into the linear weights shape.
-        # Can't do this in `weight_loader` since it already exists in
-        # `ColumnParallelLinear` and `set_weight_attrs`
-        # doesn't allow to override it
-        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
+        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
 
-        self.in_proj = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=intermediate_size + self.conv_dim + self.num_heads,
-            bias=use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
-        )
+        if n_groups % self.tp_size == 0:
+            self.conv1d = MergedColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_sizes=[
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                ],
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        # - because in_proj is a concatenation of 3 weights, we
-        # need to interleave them before sharding
-        # - use the custom weight loader mamba_v2_sharded_weight_loader
-        # for conv1d.bias, covn1d.weight and in_proj.weight
-        # - need to set these settings, to assign the groups to the head shards
-        group_shard_settings = (
-            self.n_groups * self.ssm_state_size,  # expected model size
-            (self.n_groups - n_groups) *
-            self.ssm_state_size,  # extra dims assigned
-            n_groups == 1,  # if there was only one group
-        )
-        intermediate_settings = (intermediate_size, 0, False)
-        head_settings = (self.num_heads, 0, False)
+            self.in_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[
+                    intermediate_size,
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                    self.num_heads,
+                ],
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
+        else:
+            # This is the n_groups == 1 case,
+            # where we need to duplicate groups if TP>1.
 
-        # - the weight already has a "weight_loader" attribute
-        # which set_weight_attrs will raise if we do not
-        # delete before trying to override it
-        # - ditto for the other two weights below
-        delattr(self.conv1d.bias, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.bias,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            self.conv1d = ColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_size=self.conv_dim,
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        delattr(self.conv1d.weight, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.weight,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            self.in_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size + self.conv_dim + self.num_heads,
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
 
-        if quant_config is None:
-            # - quant layers do not have a weight loader
-            delattr(self.in_proj.weight, "weight_loader")
+            # - because in_proj is a concatenation of 3 weights, we
+            # need to interleave them before sharding
+            # - use the custom weight loader mamba_v2_sharded_weight_loader
+            # for conv1d.bias, covn1d.weight and in_proj.weight
+            # - need to set these settings, to assign the groups
+            # to the head shards
+            group_shard_settings = (
+                self.groups_ssm_state_size,  # expected model size
+                (self.n_groups - n_groups) *
+                self.ssm_state_size,  # extra dims assigned
+                n_groups == 1,  # if there was only one group
+            )
+            intermediate_settings = (intermediate_size, 0, False)
+            head_settings = (self.num_heads, 0, False)
+
+            # - the weight already has a "weight_loader" attribute
+            # which set_weight_attrs will raise if we do not
+            # delete before trying to override it
+            # - ditto for the other two weights below
+            delattr(self.conv1d.bias, "weight_loader")
             set_weight_attrs(
-                self.in_proj.weight,
+                self.conv1d.bias,
                 {
                     "weight_loader":
                     mamba_v2_sharded_weight_loader(
                         [
-                            intermediate_settings,  # for gate
                             intermediate_settings,
                             group_shard_settings,
                             group_shard_settings,
-                            head_settings,  # for dt
                         ],
                         self.tp_size,
                         tp_rank,
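To see why the n_groups % tp_size == 0 branch can rely on MergedColumnParallelLinear, here is a minimal sketch in plain Python with assumed example sizes (not taken from any particular model config). A merged column-parallel layer shards each logical output chunk separately, so every rank gets 1/tp_size of the gate, hidden-state, B, C and dt chunks, and each B/C shard covers whole groups; the custom mamba_v2_sharded_weight_loader setup is then only needed in the duplicate-groups else branch:

# Assumed example sizes (illustrative only).
tp_size = 4
intermediate_size = 4096
n_groups = 8
ssm_state_size = 128
num_heads = 32

groups_ssm_state_size = n_groups * ssm_state_size           # 1024
conv_dim = intermediate_size + 2 * groups_ssm_state_size    # 4096 + 2048 = 6144

# in_proj output layout, matching output_sizes above: [gate | hidden | B | C | dt]
in_proj_output_sizes = [intermediate_size, intermediate_size,
                        groups_ssm_state_size, groups_ssm_state_size, num_heads]

# Each rank receives 1/tp_size of every chunk rather than 1/tp_size of the
# flat concatenation, so no interleaving of groups and heads is required.
per_rank_sizes = [size // tp_size for size in in_proj_output_sizes]
print(per_rank_sizes)               # [1024, 1024, 256, 256, 8]
print(256 // ssm_state_size)        # 2 whole groups per rank, since 8 % 4 == 0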
@@ -380,6 +373,50 @@ class MambaMixer2(MambaBase, CustomOp):
                 },
             )
 
+            delattr(self.conv1d.weight, "weight_loader")
+            set_weight_attrs(
+                self.conv1d.weight,
+                {
+                    "weight_loader":
+                    mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                        ],
+                        self.tp_size,
+                        tp_rank,
+                    )
+                },
+            )
+
+            if quant_config is None:
+                # - quant layers do not have a weight loader
+                delattr(self.in_proj.weight, "weight_loader")
+                set_weight_attrs(
+                    self.in_proj.weight,
+                    {
+                        "weight_loader":
+                        mamba_v2_sharded_weight_loader(
+                            [
+                                intermediate_settings,  # for gate
+                                intermediate_settings,
+                                group_shard_settings,
+                                group_shard_settings,
+                                head_settings,  # for dt
+                            ],
+                            self.tp_size,
+                            tp_rank,
+                        )
+                    },
+                )
+
+        # unsqueeze to fit conv1d weights shape into the linear weights shape.
+        # Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
+        # and `set_weight_attrs` doesn't allow to override it
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
         # - these are TPed by heads to reduce the size of the
         # temporal shape
         self.A = nn.Parameter(
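The unsqueeze at the end of the hunk above reconciles two layouts: the parallel linear layers store the conv1d parameters as a 2D (output, kernel) matrix so the stock weight loaders can shard them, while the convolution itself expects a 3D weight. A shape-only sketch with assumed sizes (plain PyTorch, not the vLLM classes):

import torch

conv_dim, kernel_size = 6144, 4           # assumed example sizes
w = torch.empty(conv_dim, kernel_size)    # layout seen by the linear-layer loaders
w3d = w.unsqueeze(1)                      # insert a singleton channel dimension
print(w3d.shape)                          # torch.Size([6144, 1, 4]), consistent with
                                          # a depthwise conv1d weight (one channel per dim)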
@@ -498,8 +535,6 @@ class MambaMixer2(MambaBase, CustomOp):
             chunk_indices_p = mamba2_metadata.chunk_indices
             chunk_offsets_p = mamba2_metadata.chunk_offsets
 
-        groups_time_state_size = self.n_groups * self.ssm_state_size
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
 
@@ -524,8 +559,8 @@ class MambaMixer2(MambaBase, CustomOp):
             hidden_states_B_C,
             [
                 self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
             ],
             dim=-1,
         )
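The forward-path change above mirrors the rename: each rank splits its local slice of the fused conv output back into hidden states, B and C using per-rank chunk sizes. A self-contained sketch of that split with the same assumed example sizes (plain torch.split, not the vLLM code path):

import torch

tp_size = 4
intermediate_size = 4096
groups_ssm_state_size = 1024              # n_groups * ssm_state_size
conv_dim = intermediate_size + 2 * groups_ssm_state_size   # 6144

seq_len = 3
local = torch.randn(seq_len, conv_dim // tp_size)   # this rank's slice: (3, 1536)
hidden, B, C = torch.split(
    local,
    [intermediate_size // tp_size,            # 1024
     groups_ssm_state_size // tp_size,        # 256
     groups_ssm_state_size // tp_size],       # 256
    dim=-1,
)
print(hidden.shape, B.shape, C.shape)     # (3, 1024) (3, 256) (3, 256)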