Enforce that TP > 1 is not supported for Mamba2 if Quantization is Enabled. (#14617)

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Yu Chin Fabian Lim 2025-03-21 08:44:37 +08:00 committed by GitHub
parent 2b22290ce0
commit 06dd08256f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -251,6 +251,9 @@ class MambaMixer2(CustomOp):
"then num_groups must equal 1." "then num_groups must equal 1."
) )
assert self.tp_size == 1 or quant_config is None, \
"Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.activation = activation self.activation = activation
@ -331,22 +334,24 @@ class MambaMixer2(CustomOp):
], self.tp_size, tp_rank) ], self.tp_size, tp_rank)
}) })
delattr(self.in_proj.weight, "weight_loader") if quant_config is None:
set_weight_attrs( # - quant layers do not have a weight loader
self.in_proj.weight, delattr(self.in_proj.weight, "weight_loader")
{ set_weight_attrs(
"weight_loader": self.in_proj.weight,
mamba_v2_sharded_weight_loader( {
[ "weight_loader":
intermediate_settings, # for gate mamba_v2_sharded_weight_loader(
intermediate_settings, [
group_shard_settings, intermediate_settings, # for gate
group_shard_settings, intermediate_settings,
head_setings, # for dt group_shard_settings,
], group_shard_settings,
self.tp_size, head_setings, # for dt
tp_rank) ],
}) self.tp_size,
tp_rank)
})
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape