From 06dd08256f076689945418cd61397c1759f4abfa Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 21 Mar 2025 08:44:37 +0800
Subject: [PATCH] Enforce that TP > 1 is not supported for Mamba2 if
 Quantization is Enabled. (#14617)

Signed-off-by: Yu Chin Fabian Lim
---
 .../layers/mamba/mamba_mixer2.py              | 37 +++++++++++--------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 53d68b60f2fde..fec6d6112d665 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -251,6 +251,9 @@ class MambaMixer2(CustomOp):
             "then num_groups must equal 1."
         )
 
+        assert self.tp_size == 1 or quant_config is None, \
+            "Tensor parallel currently not supported for quantized models."
+
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
@@ -331,22 +334,24 @@ class MambaMixer2(CustomOp):
                 ], self.tp_size, tp_rank)
             })
 
-        delattr(self.in_proj.weight, "weight_loader")
-        set_weight_attrs(
-            self.in_proj.weight,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,  # for gate
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                        head_setings,  # for dt
-                    ],
-                    self.tp_size,
-                    tp_rank)
-            })
+        if quant_config is None:
+            # - quant layers do not have a weight loader
+            delattr(self.in_proj.weight, "weight_loader")
+            set_weight_attrs(
+                self.in_proj.weight,
+                {
+                    "weight_loader":
+                    mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,  # for gate
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                            head_setings,  # for dt
+                        ],
+                        self.tp_size,
+                        tp_rank)
+                })
 
         # - these are TPed by heads to reduce the size of the
         # temporal shape
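
Editor's note, not part of the patch: a minimal sketch of how the guard added
in the first hunk behaves. MambaMixer2Sketch is a hypothetical stand-in class,
not the vLLM source; only the assert line is taken verbatim from the diff.

    # Sketch only: mirrors the assert this patch adds to MambaMixer2.__init__.
    class MambaMixer2Sketch:
        def __init__(self, tp_size: int, quant_config=None):
            self.tp_size = tp_size
            # Per the second hunk, quantized layers have no "weight_loader"
            # attribute for the sharded loader to replace, so weights cannot
            # be sharded across TP ranks; reject the combination up front.
            assert self.tp_size == 1 or quant_config is None, \
                "Tensor parallel currently not supported for quantized models."

    MambaMixer2Sketch(tp_size=2, quant_config=None)      # OK: no quantization
    MambaMixer2Sketch(tp_size=1, quant_config=object())  # OK: TP size is 1
    try:
        MambaMixer2Sketch(tp_size=2, quant_config=object())
    except AssertionError as err:
        print(err)  # TP > 1 with quantization enabled is rejected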